From e9ae67ff0721296cd1edfc44e6461a59c7632a25 Mon Sep 17 00:00:00 2001 From: Jon Dufresne Date: Tue, 4 Feb 2020 05:07:58 -0800 Subject: [PATCH 1/8] Improve resource ownership semantics of _solve_conn_curs() A new cursor is always creates so always close it once finished. --- lib/_json.py | 27 +++++++------- lib/_range.py | 51 +++++++++++++------------- lib/extras.py | 73 +++++++++++++++++++------------------- tests/test_types_extras.py | 28 ++++++++++++++- 4 files changed, 102 insertions(+), 77 deletions(-) diff --git a/lib/_json.py b/lib/_json.py index eac37972..9dcc227c 100644 --- a/lib/_json.py +++ b/lib/_json.py @@ -180,23 +180,22 @@ def _get_json_oids(conn_or_curs, name='json'): from psycopg2.extensions import STATUS_IN_TRANSACTION from psycopg2.extras import _solve_conn_curs - conn, curs = _solve_conn_curs(conn_or_curs) + with _solve_conn_curs(conn_or_curs) as (conn, curs): + # Store the transaction status of the connection to revert it after use + conn_status = conn.status - # Store the transaction status of the connection to revert it after use - conn_status = conn.status + # column typarray not available before PG 8.3 + typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" - # column typarray not available before PG 8.3 - typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" + # get the oid for the hstore + curs.execute( + "SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;" + % typarray, (name,)) + r = curs.fetchone() - # get the oid for the hstore - curs.execute( - "SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;" - % typarray, (name,)) - r = curs.fetchone() - - # revert the status of the connection as before the command - if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit: - conn.rollback() + # revert the status of the connection as before the command + if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit: + conn.rollback() if not r: raise conn.ProgrammingError("%s data type not found" % name) diff --git a/lib/_range.py b/lib/_range.py index b668fb63..607d85b0 100644 --- a/lib/_range.py +++ b/lib/_range.py @@ -351,25 +351,24 @@ class RangeCaster(object): """ from psycopg2.extensions import STATUS_IN_TRANSACTION from psycopg2.extras import _solve_conn_curs - conn, curs = _solve_conn_curs(conn_or_curs) + with _solve_conn_curs(conn_or_curs) as (conn, curs): + if conn.info.server_version < 90200: + raise ProgrammingError("range types not available in version %s" + % conn.info.server_version) - if conn.info.server_version < 90200: - raise ProgrammingError("range types not available in version %s" - % conn.info.server_version) + # Store the transaction status of the connection to revert it after use + conn_status = conn.status - # Store the transaction status of the connection to revert it after use - conn_status = conn.status + # Use the correct schema + if '.' in name: + schema, tname = name.split('.', 1) + else: + tname = name + schema = 'public' - # Use the correct schema - if '.' in name: - schema, tname = name.split('.', 1) - else: - tname = name - schema = 'public' - - # get the type oid and attributes - try: - curs.execute("""\ + # get the type oid and attributes + try: + curs.execute("""\ select rngtypid, rngsubtype, (select typarray from pg_type where oid = rngtypid) from pg_range r @@ -378,17 +377,17 @@ join pg_namespace ns on ns.oid = typnamespace where typname = %s and ns.nspname = %s; """, (tname, schema)) - except ProgrammingError: - if not conn.autocommit: - conn.rollback() - raise - else: - rec = curs.fetchone() + except ProgrammingError: + if not conn.autocommit: + conn.rollback() + raise + else: + rec = curs.fetchone() - # revert the status of the connection as before the command - if (conn_status != STATUS_IN_TRANSACTION - and not conn.autocommit): - conn.rollback() + # revert the status of the connection as before the command + if (conn_status != STATUS_IN_TRANSACTION + and not conn.autocommit): + conn.rollback() if not rec: raise ProgrammingError( diff --git a/lib/extras.py b/lib/extras.py index 236a5b7e..6c378bb2 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -26,6 +26,7 @@ and classes until a better place in the distribution is found. # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # License for more details. +import contextlib import os as _os import time as _time import re as _re @@ -781,6 +782,7 @@ def wait_select(conn): continue +@contextlib.contextmanager def _solve_conn_curs(conn_or_curs): """Return the connection and a DBAPI cursor from a connection or cursor.""" if conn_or_curs is None: @@ -793,7 +795,8 @@ def _solve_conn_curs(conn_or_curs): conn = conn_or_curs curs = conn.cursor(cursor_factory=_cursor) - return conn, curs + with curs: + yield conn, curs class HstoreAdapter(object): @@ -904,31 +907,30 @@ class HstoreAdapter(object): def get_oids(self, conn_or_curs): """Return the lists of OID of the hstore and hstore[] types. """ - conn, curs = _solve_conn_curs(conn_or_curs) + with _solve_conn_curs(conn_or_curs) as (conn, curs): + # Store the transaction status of the connection to revert it after use + conn_status = conn.status - # Store the transaction status of the connection to revert it after use - conn_status = conn.status + # column typarray not available before PG 8.3 + typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" - # column typarray not available before PG 8.3 - typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" + rv0, rv1 = [], [] - rv0, rv1 = [], [] - - # get the oid for the hstore - curs.execute("""\ + # get the oid for the hstore + curs.execute("""\ SELECT t.oid, %s FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid WHERE typname = 'hstore'; """ % typarray) - for oids in curs: - rv0.append(oids[0]) - rv1.append(oids[1]) + for oids in curs: + rv0.append(oids[0]) + rv1.append(oids[1]) - # revert the status of the connection as before the command - if (conn_status != _ext.STATUS_IN_TRANSACTION - and not conn.autocommit): - conn.rollback() + # revert the status of the connection as before the command + if (conn_status != _ext.STATUS_IN_TRANSACTION + and not conn.autocommit): + conn.rollback() return tuple(rv0), tuple(rv1) @@ -1083,23 +1085,22 @@ class CompositeCaster(object): Raise `ProgrammingError` if the type is not found. """ - conn, curs = _solve_conn_curs(conn_or_curs) + with _solve_conn_curs(conn_or_curs) as (conn, curs): + # Store the transaction status of the connection to revert it after use + conn_status = conn.status - # Store the transaction status of the connection to revert it after use - conn_status = conn.status + # Use the correct schema + if '.' in name: + schema, tname = name.split('.', 1) + else: + tname = name + schema = 'public' - # Use the correct schema - if '.' in name: - schema, tname = name.split('.', 1) - else: - tname = name - schema = 'public' + # column typarray not available before PG 8.3 + typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" - # column typarray not available before PG 8.3 - typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" - - # get the type oid and attributes - curs.execute("""\ + # get the type oid and attributes + curs.execute("""\ SELECT t.oid, %s, attname, atttypid FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid @@ -1109,12 +1110,12 @@ WHERE typname = %%s AND nspname = %%s ORDER BY attnum; """ % typarray, (tname, schema)) - recs = curs.fetchall() + recs = curs.fetchall() - # revert the status of the connection as before the command - if (conn_status != _ext.STATUS_IN_TRANSACTION - and not conn.autocommit): - conn.rollback() + # revert the status of the connection as before the command + if (conn_status != _ext.STATUS_IN_TRANSACTION + and not conn.autocommit): + conn.rollback() if not recs: raise psycopg2.ProgrammingError( diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py index 91e4a8ea..e4828077 100755 --- a/tests/test_types_extras.py +++ b/tests/test_types_extras.py @@ -36,7 +36,7 @@ from psycopg2._json import _get_json_oids from psycopg2.extras import ( CompositeCaster, DateRange, DateTimeRange, DateTimeTZRange, HstoreAdapter, Inet, Json, NumericRange, Range, RealDictConnection, - register_composite, register_hstore, register_range, + register_composite, register_hstore, register_range, _solve_conn_curs ) from psycopg2.tz import FixedOffsetTimezone @@ -1624,6 +1624,32 @@ class RangeCasterTestCase(ConnectingTestCase): cur.execute("rollback to savepoint x;") +class TestSolveConnCurs(ConnectingTestCase): + def test_pass_connection(self): + with _solve_conn_curs(self.conn) as (conn, curs): + self.assertIsInstance(conn, psycopg2.extensions.connection) + self.assertIsInstance(curs, psycopg2.extensions.cursor) + self.assertIs(conn, self.conn) + self.assertFalse(conn.closed) + self.assertFalse(curs.closed) + self.assertFalse(conn.closed) + self.assertTrue(curs.closed) + + def test_pass_cursor(self): + with self.conn.cursor() as cursor: + with _solve_conn_curs(cursor) as (conn, curs): + self.assertIsInstance(conn, psycopg2.extensions.connection) + self.assertIsInstance(curs, psycopg2.extensions.cursor) + self.assertIs(conn, self.conn) + self.assertIsNot(curs, cursor) + self.assertFalse(conn.closed) + self.assertFalse(curs.closed) + self.assertFalse(conn.closed) + self.assertTrue(curs.closed) + self.assertFalse(conn.closed) + self.assertTrue(curs.closed) + + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__) From 6b63fae20a94d3538ff4bacba1408444fb78aaca Mon Sep 17 00:00:00 2001 From: Jon Dufresne Date: Sat, 1 Feb 2020 08:11:09 -0800 Subject: [PATCH 2/8] Always close connection objects in tests --- tests/dbapi20.py | 2 ++ tests/dbapi20_tpc.py | 25 ++++++++++++++----------- tests/test_async.py | 4 +++- tests/test_async_keyword.py | 3 +++ tests/test_cancel.py | 1 + tests/testutils.py | 18 ++++++++++-------- 6 files changed, 33 insertions(+), 20 deletions(-) diff --git a/tests/dbapi20.py b/tests/dbapi20.py index fe89bb0e..5209580f 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -232,6 +232,7 @@ class DatabaseAPI20Test(unittest.TestCase): self.failUnless(con.InternalError is drv.InternalError) self.failUnless(con.ProgrammingError is drv.ProgrammingError) self.failUnless(con.NotSupportedError is drv.NotSupportedError) + con.close() def test_commit(self): @@ -251,6 +252,7 @@ class DatabaseAPI20Test(unittest.TestCase): con.rollback() except self.driver.NotSupportedError: pass + con.close() def test_cursor(self): con = self._connect() diff --git a/tests/dbapi20_tpc.py b/tests/dbapi20_tpc.py index d4790f71..c6c87b24 100644 --- a/tests/dbapi20_tpc.py +++ b/tests/dbapi20_tpc.py @@ -24,19 +24,22 @@ class TwoPhaseCommitTests(unittest.TestCase): def test_xid(self): con = self.connect() try: - xid = con.xid(42, "global", "bqual") - except self.driver.NotSupportedError: - self.fail("Driver does not support transaction IDs.") + try: + xid = con.xid(42, "global", "bqual") + except self.driver.NotSupportedError: + self.fail("Driver does not support transaction IDs.") - self.assertEquals(xid[0], 42) - self.assertEquals(xid[1], "global") - self.assertEquals(xid[2], "bqual") + self.assertEquals(xid[0], 42) + self.assertEquals(xid[1], "global") + self.assertEquals(xid[2], "bqual") - # Try some extremes for the transaction ID: - xid = con.xid(0, "", "") - self.assertEquals(tuple(xid), (0, "", "")) - xid = con.xid(0x7fffffff, "a" * 64, "b" * 64) - self.assertEquals(tuple(xid), (0x7fffffff, "a" * 64, "b" * 64)) + # Try some extremes for the transaction ID: + xid = con.xid(0, "", "") + self.assertEquals(tuple(xid), (0, "", "")) + xid = con.xid(0x7fffffff, "a" * 64, "b" * 64) + self.assertEquals(tuple(xid), (0x7fffffff, "a" * 64, "b" * 64)) + finally: + con.close() def test_tpc_begin(self): con = self.connect() diff --git a/tests/test_async.py b/tests/test_async.py index d62eb3b0..8624be4a 100755 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -451,14 +451,16 @@ class AsyncTests(ConnectingTestCase): self.assertEqual(cur.fetchone(), (42,)) def test_async_connection_error_message(self): + cnn = psycopg2.connect('dbname=thisdatabasedoesntexist', async_=True) try: - cnn = psycopg2.connect('dbname=thisdatabasedoesntexist', async_=True) self.wait(cnn) except psycopg2.Error as e: self.assertNotEqual(str(e), "asynchronous connection failed", "connection error reason lost") else: self.fail("no exception raised") + finally: + cnn.close() @skip_before_postgres(8, 2) def test_copy_no_hang(self): diff --git a/tests/test_async_keyword.py b/tests/test_async_keyword.py index e1126928..f8e50afe 100755 --- a/tests/test_async_keyword.py +++ b/tests/test_async_keyword.py @@ -89,6 +89,8 @@ class AsyncTests(ConnectingTestCase): "connection error reason lost") else: self.fail("no exception raised") + finally: + cnn.close() class CancelTests(ConnectingTestCase): @@ -118,6 +120,7 @@ class CancelTests(ConnectingTestCase): cur.execute("select 1") extras.wait_select(async_conn) self.assertEqual(cur.fetchall(), [(1, )]) + async_conn.close() def test_async_connection_cancel(self): async_conn = psycopg2.connect(dsn, async=True) diff --git a/tests/test_cancel.py b/tests/test_cancel.py index 4c60c0b7..06477edc 100755 --- a/tests/test_cancel.py +++ b/tests/test_cancel.py @@ -105,6 +105,7 @@ class CancelTests(ConnectingTestCase): cur.execute("select 1") extras.wait_select(async_conn) self.assertEqual(cur.fetchall(), [(1, )]) + async_conn.close() def test_async_connection_cancel(self): async_conn = psycopg2.connect(dsn, async_=True) diff --git a/tests/testutils.py b/tests/testutils.py index 26f6cc71..cd69c468 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -248,15 +248,17 @@ def skip_if_tpc_disabled(f): @wraps(f) def skip_if_tpc_disabled_(self): cnn = self.connect() - cur = cnn.cursor() try: - cur.execute("SHOW max_prepared_transactions;") - except psycopg2.ProgrammingError: - return self.skipTest( - "server too old: two phase transactions not supported.") - else: - mtp = int(cur.fetchone()[0]) - cnn.close() + cur = cnn.cursor() + try: + cur.execute("SHOW max_prepared_transactions;") + except psycopg2.ProgrammingError: + return self.skipTest( + "server too old: two phase transactions not supported.") + else: + mtp = int(cur.fetchone()[0]) + finally: + cnn.close() if not mtp: return self.skipTest( From de858b9cb2ef59760a1d1a4bdcf155b61e090e6e Mon Sep 17 00:00:00 2001 From: Jon Dufresne Date: Sun, 2 Feb 2020 16:20:52 -0800 Subject: [PATCH 3/8] Always close cursor objects in tests --- tests/test_async.py | 547 +++++++------ tests/test_async_keyword.py | 97 +-- tests/test_bug_gc.py | 6 +- tests/test_cancel.py | 64 +- tests/test_connection.py | 1264 +++++++++++++++---------------- tests/test_copy.py | 218 +++--- tests/test_cursor.py | 529 +++++++------ tests/test_dates.py | 68 +- tests/test_extras_dictcursor.py | 597 ++++++++------- tests/test_fast_executemany.py | 320 ++++---- tests/test_green.py | 105 +-- tests/test_ipaddress.py | 108 +-- tests/test_module.py | 30 +- tests/test_notify.py | 9 +- tests/test_quote.py | 147 ++-- tests/test_replication.py | 161 ++-- tests/test_sql.py | 90 +-- tests/test_transaction.py | 178 ++--- tests/test_types_basic.py | 141 ++-- tests/test_types_extras.py | 1110 ++++++++++++++------------- tests/test_with.py | 96 +-- tests/testutils.py | 22 +- 22 files changed, 2954 insertions(+), 2953 deletions(-) diff --git a/tests/test_async.py b/tests/test_async.py index 8624be4a..7a237b5e 100755 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -61,16 +61,18 @@ class AsyncTests(ConnectingTestCase): self.wait(self.conn) - curs = self.conn.cursor() - curs.execute(''' - CREATE TEMPORARY TABLE table1 ( - id int PRIMARY KEY - )''') - self.wait(curs) + with self.conn.cursor() as curs: + curs.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') + self.wait(curs) def test_connection_setup(self): cur = self.conn.cursor() sync_cur = self.sync_conn.cursor() + cur.close() + sync_cur.close() del cur, sync_cur self.assert_(self.conn.async_) @@ -90,159 +92,156 @@ class AsyncTests(ConnectingTestCase): self.conn.cursor, "name") def test_async_select(self): - cur = self.conn.cursor() - self.assertFalse(self.conn.isexecuting()) - cur.execute("select 'a'") - self.assertTrue(self.conn.isexecuting()) + with self.conn.cursor() as cur: + self.assertFalse(self.conn.isexecuting()) + cur.execute("select 'a'") + self.assertTrue(self.conn.isexecuting()) - self.wait(cur) + self.wait(cur) - self.assertFalse(self.conn.isexecuting()) - self.assertEquals(cur.fetchone()[0], "a") + self.assertFalse(self.conn.isexecuting()) + self.assertEquals(cur.fetchone()[0], "a") @slow @skip_before_postgres(8, 2) def test_async_callproc(self): - cur = self.conn.cursor() - cur.callproc("pg_sleep", (0.1, )) - self.assertTrue(self.conn.isexecuting()) + with self.conn.cursor() as cur: + cur.callproc("pg_sleep", (0.1, )) + self.assertTrue(self.conn.isexecuting()) - self.wait(cur) - self.assertFalse(self.conn.isexecuting()) - self.assertEquals(cur.fetchall()[0][0], '') + self.wait(cur) + self.assertFalse(self.conn.isexecuting()) + self.assertEquals(cur.fetchall()[0][0], '') @slow def test_async_after_async(self): - cur = self.conn.cursor() - cur2 = self.conn.cursor() - del cur2 + with self.conn.cursor() as cur: + cur2 = self.conn.cursor() + cur2.close() + del cur2 - cur.execute("insert into table1 values (1)") + cur.execute("insert into table1 values (1)") - # an async execute after an async one raises an exception - self.assertRaises(psycopg2.ProgrammingError, - cur.execute, "select * from table1") - # same for callproc - self.assertRaises(psycopg2.ProgrammingError, - cur.callproc, "version") - # but after you've waited it should be good - self.wait(cur) - cur.execute("select * from table1") - self.wait(cur) + # an async execute after an async one raises an exception + self.assertRaises(psycopg2.ProgrammingError, + cur.execute, "select * from table1") + # same for callproc + self.assertRaises(psycopg2.ProgrammingError, + cur.callproc, "version") + # but after you've waited it should be good + self.wait(cur) + cur.execute("select * from table1") + self.wait(cur) - self.assertEquals(cur.fetchall()[0][0], 1) + self.assertEquals(cur.fetchall()[0][0], 1) - cur.execute("delete from table1") - self.wait(cur) + cur.execute("delete from table1") + self.wait(cur) - cur.execute("select * from table1") - self.wait(cur) + cur.execute("select * from table1") + self.wait(cur) - self.assertEquals(cur.fetchone(), None) + self.assertEquals(cur.fetchone(), None) def test_fetch_after_async(self): - cur = self.conn.cursor() - cur.execute("select 'a'") + with self.conn.cursor() as cur: + cur.execute("select 'a'") - # a fetch after an asynchronous query should raise an error - self.assertRaises(psycopg2.ProgrammingError, - cur.fetchall) - # but after waiting it should work - self.wait(cur) - self.assertEquals(cur.fetchall()[0][0], "a") + # a fetch after an asynchronous query should raise an error + self.assertRaises(psycopg2.ProgrammingError, + cur.fetchall) + # but after waiting it should work + self.wait(cur) + self.assertEquals(cur.fetchall()[0][0], "a") def test_rollback_while_async(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + cur.execute("select 'a'") - cur.execute("select 'a'") - - # a rollback should not work in asynchronous mode - self.assertRaises(psycopg2.ProgrammingError, self.conn.rollback) + # a rollback should not work in asynchronous mode + self.assertRaises(psycopg2.ProgrammingError, self.conn.rollback) def test_commit_while_async(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + cur.execute("begin") + self.wait(cur) - cur.execute("begin") - self.wait(cur) + cur.execute("insert into table1 values (1)") - cur.execute("insert into table1 values (1)") + # a commit should not work in asynchronous mode + self.assertRaises(psycopg2.ProgrammingError, self.conn.commit) + self.assertTrue(self.conn.isexecuting()) - # a commit should not work in asynchronous mode - self.assertRaises(psycopg2.ProgrammingError, self.conn.commit) - self.assertTrue(self.conn.isexecuting()) + # but a manual commit should + self.wait(cur) + cur.execute("commit") + self.wait(cur) - # but a manual commit should - self.wait(cur) - cur.execute("commit") - self.wait(cur) + cur.execute("select * from table1") + self.wait(cur) + self.assertEquals(cur.fetchall()[0][0], 1) - cur.execute("select * from table1") - self.wait(cur) - self.assertEquals(cur.fetchall()[0][0], 1) + cur.execute("delete from table1") + self.wait(cur) - cur.execute("delete from table1") - self.wait(cur) - - cur.execute("select * from table1") - self.wait(cur) - self.assertEquals(cur.fetchone(), None) + cur.execute("select * from table1") + self.wait(cur) + self.assertEquals(cur.fetchone(), None) def test_set_parameters_while_async(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + cur.execute("select 'c'") + self.assertTrue(self.conn.isexecuting()) - cur.execute("select 'c'") - self.assertTrue(self.conn.isexecuting()) + # getting transaction status works + self.assertEquals(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_ACTIVE) + self.assertTrue(self.conn.isexecuting()) - # getting transaction status works - self.assertEquals(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_ACTIVE) - self.assertTrue(self.conn.isexecuting()) + # setting connection encoding should fail + self.assertRaises(psycopg2.ProgrammingError, + self.conn.set_client_encoding, "LATIN1") - # setting connection encoding should fail - self.assertRaises(psycopg2.ProgrammingError, - self.conn.set_client_encoding, "LATIN1") - - # same for transaction isolation - self.assertRaises(psycopg2.ProgrammingError, - self.conn.set_isolation_level, 1) + # same for transaction isolation + self.assertRaises(psycopg2.ProgrammingError, + self.conn.set_isolation_level, 1) def test_reset_while_async(self): - cur = self.conn.cursor() - cur.execute("select 'c'") - self.assertTrue(self.conn.isexecuting()) + with self.conn.cursor() as cur: + cur.execute("select 'c'") + self.assertTrue(self.conn.isexecuting()) - # a reset should fail - self.assertRaises(psycopg2.ProgrammingError, self.conn.reset) + # a reset should fail + self.assertRaises(psycopg2.ProgrammingError, self.conn.reset) def test_async_iter(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + cur.execute("begin") + self.wait(cur) + cur.execute(""" + insert into table1 values (1); + insert into table1 values (2); + insert into table1 values (3); + """) + self.wait(cur) + cur.execute("select id from table1 order by id") - cur.execute("begin") - self.wait(cur) - cur.execute(""" - insert into table1 values (1); - insert into table1 values (2); - insert into table1 values (3); - """) - self.wait(cur) - cur.execute("select id from table1 order by id") + # iteration fails if a query is underway + self.assertRaises(psycopg2.ProgrammingError, list, cur) - # iteration fails if a query is underway - self.assertRaises(psycopg2.ProgrammingError, list, cur) - - # but after it's done it should work - self.wait(cur) - self.assertEquals(list(cur), [(1, ), (2, ), (3, )]) - self.assertFalse(self.conn.isexecuting()) + # but after it's done it should work + self.wait(cur) + self.assertEquals(list(cur), [(1, ), (2, ), (3, )]) + self.assertFalse(self.conn.isexecuting()) def test_copy_while_async(self): - cur = self.conn.cursor() - cur.execute("select 'a'") + with self.conn.cursor() as cur: + cur.execute("select 'a'") - # copy should fail - self.assertRaises(psycopg2.ProgrammingError, - cur.copy_from, - StringIO("1\n3\n5\n\\.\n"), "table1") + # copy should fail + self.assertRaises(psycopg2.ProgrammingError, + cur.copy_from, + StringIO("1\n3\n5\n\\.\n"), "table1") def test_lobject_while_async(self): # large objects should be prohibited @@ -250,68 +249,68 @@ class AsyncTests(ConnectingTestCase): self.conn.lobject) def test_async_executemany(self): - cur = self.conn.cursor() - self.assertRaises( - psycopg2.ProgrammingError, - cur.executemany, "insert into table1 values (%s)", [1, 2, 3]) + with self.conn.cursor() as cur: + self.assertRaises( + psycopg2.ProgrammingError, + cur.executemany, "insert into table1 values (%s)", [1, 2, 3]) def test_async_scroll(self): - cur = self.conn.cursor() - cur.execute(""" - insert into table1 values (1); - insert into table1 values (2); - insert into table1 values (3); - """) - self.wait(cur) - cur.execute("select id from table1 order by id") + with self.conn.cursor() as cur: + cur.execute(""" + insert into table1 values (1); + insert into table1 values (2); + insert into table1 values (3); + """) + self.wait(cur) + cur.execute("select id from table1 order by id") - # scroll should fail if a query is underway - self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 1) - self.assertTrue(self.conn.isexecuting()) + # scroll should fail if a query is underway + self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 1) + self.assertTrue(self.conn.isexecuting()) - # but after it's done it should work - self.wait(cur) - cur.scroll(1) - self.assertEquals(cur.fetchall(), [(2, ), (3, )]) + # but after it's done it should work + self.wait(cur) + cur.scroll(1) + self.assertEquals(cur.fetchall(), [(2, ), (3, )]) - cur = self.conn.cursor() - cur.execute("select id from table1 order by id") - self.wait(cur) + with self.conn.cursor() as cur2: + cur.execute("select id from table1 order by id") + self.wait(cur) - cur2 = self.conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1) + with self.conn.cursor() as cur2: + self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1) - self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4) + self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4) - cur = self.conn.cursor() - cur.execute("select id from table1 order by id") - self.wait(cur) - cur.scroll(2) - cur.scroll(-1) - self.assertEquals(cur.fetchall(), [(2, ), (3, )]) + with self.conn.cursor() as cur: + cur.execute("select id from table1 order by id") + self.wait(cur) + cur.scroll(2) + cur.scroll(-1) + self.assertEquals(cur.fetchall(), [(2, ), (3, )]) def test_scroll(self): - cur = self.sync_conn.cursor() - cur.execute("create table table1 (id int)") - cur.execute(""" - insert into table1 values (1); - insert into table1 values (2); - insert into table1 values (3); - """) - cur.execute("select id from table1 order by id") - cur.scroll(2) - cur.scroll(-1) - self.assertEquals(cur.fetchall(), [(2, ), (3, )]) + with self.sync_conn.cursor() as cur: + cur.execute("create table table1 (id int)") + cur.execute(""" + insert into table1 values (1); + insert into table1 values (2); + insert into table1 values (3); + """) + cur.execute("select id from table1 order by id") + cur.scroll(2) + cur.scroll(-1) + self.assertEquals(cur.fetchall(), [(2, ), (3, )]) def test_async_dont_read_all(self): - cur = self.conn.cursor() - cur.execute("select repeat('a', 10000); select repeat('b', 10000)") + with self.conn.cursor() as cur: + cur.execute("select repeat('a', 10000); select repeat('b', 10000)") - # fetch the result - self.wait(cur) + # fetch the result + self.wait(cur) - # it should be the result of the second query - self.assertEquals(cur.fetchone()[0], "b" * 10000) + # it should be the result of the second query + self.assertEquals(cur.fetchone()[0], "b" * 10000) def test_async_subclass(self): class MyConn(ext.connection): @@ -326,15 +325,15 @@ class AsyncTests(ConnectingTestCase): @slow def test_flush_on_write(self): # a very large query requires a flush loop to be sent to the backend - curs = self.conn.cursor() - for mb in 1, 5, 10, 20, 50: - size = mb * 1024 * 1024 - stub = PollableStub(self.conn) - curs.execute("select %s;", ('x' * size,)) - self.wait(stub) - self.assertEqual(size, len(curs.fetchone()[0])) - if stub.polls.count(ext.POLL_WRITE) > 1: - return + with self.conn.cursor() as curs: + for mb in 1, 5, 10, 20, 50: + size = mb * 1024 * 1024 + stub = PollableStub(self.conn) + curs.execute("select %s;", ('x' * size,)) + self.wait(stub) + self.assertEqual(size, len(curs.fetchone()[0])) + if stub.polls.count(ext.POLL_WRITE) > 1: + return # This is more a testing glitch than an error: it happens # on high load on linux: probably because the kernel has more @@ -343,112 +342,108 @@ class AsyncTests(ConnectingTestCase): warnings.warn("sending a large query didn't trigger block on write.") def test_sync_poll(self): - cur = self.sync_conn.cursor() - cur.execute("select 1") - # polling with a sync query works - cur.connection.poll() - self.assertEquals(cur.fetchone()[0], 1) + with self.sync_conn.cursor() as cur: + cur.execute("select 1") + # polling with a sync query works + cur.connection.poll() + self.assertEquals(cur.fetchone()[0], 1) @slow def test_notify(self): - cur = self.conn.cursor() - sync_cur = self.sync_conn.cursor() + with self.conn.cursor() as cur, self.sync_conn.cursor() as sync_cur: + sync_cur.execute("listen test_notify") + self.sync_conn.commit() + cur.execute("notify test_notify") + self.wait(cur) - sync_cur.execute("listen test_notify") - self.sync_conn.commit() - cur.execute("notify test_notify") - self.wait(cur) + self.assertEquals(self.sync_conn.notifies, []) - self.assertEquals(self.sync_conn.notifies, []) - - pid = self.conn.info.backend_pid - for _ in range(5): - self.wait(self.sync_conn) - if not self.sync_conn.notifies: - time.sleep(0.5) - continue - self.assertEquals(len(self.sync_conn.notifies), 1) - self.assertEquals(self.sync_conn.notifies.pop(), - (pid, "test_notify")) - return + pid = self.conn.info.backend_pid + for _ in range(5): + self.wait(self.sync_conn) + if not self.sync_conn.notifies: + time.sleep(0.5) + continue + self.assertEquals(len(self.sync_conn.notifies), 1) + self.assertEquals(self.sync_conn.notifies.pop(), + (pid, "test_notify")) + return self.fail("No NOTIFY in 2.5 seconds") def test_async_fetch_wrong_cursor(self): - cur1 = self.conn.cursor() - cur2 = self.conn.cursor() - cur1.execute("select 1") + with self.conn.cursor() as cur1, self.conn.cursor() as cur2: + cur1.execute("select 1") - self.wait(cur1) - self.assertFalse(self.conn.isexecuting()) - # fetching from a cursor with no results is an error - self.assertRaises(psycopg2.ProgrammingError, cur2.fetchone) - # fetching from the correct cursor works - self.assertEquals(cur1.fetchone()[0], 1) + self.wait(cur1) + self.assertFalse(self.conn.isexecuting()) + # fetching from a cursor with no results is an error + self.assertRaises(psycopg2.ProgrammingError, cur2.fetchone) + # fetching from the correct cursor works + self.assertEquals(cur1.fetchone()[0], 1) def test_error(self): - cur = self.conn.cursor() - cur.execute("insert into table1 values (%s)", (1, )) - self.wait(cur) - cur.execute("insert into table1 values (%s)", (1, )) - # this should fail - self.assertRaises(psycopg2.IntegrityError, self.wait, cur) - cur.execute("insert into table1 values (%s); " - "insert into table1 values (%s)", (2, 2)) - # this should fail as well - self.assertRaises(psycopg2.IntegrityError, self.wait, cur) - # but this should work - cur.execute("insert into table1 values (%s)", (2, )) - self.wait(cur) - # and the cursor should be usable afterwards - cur.execute("insert into table1 values (%s)", (3, )) - self.wait(cur) - cur.execute("select * from table1 order by id") - self.wait(cur) - self.assertEquals(cur.fetchall(), [(1, ), (2, ), (3, )]) - cur.execute("delete from table1") - self.wait(cur) + with self.conn.cursor() as cur: + cur.execute("insert into table1 values (%s)", (1, )) + self.wait(cur) + cur.execute("insert into table1 values (%s)", (1, )) + # this should fail + self.assertRaises(psycopg2.IntegrityError, self.wait, cur) + cur.execute("insert into table1 values (%s); " + "insert into table1 values (%s)", (2, 2)) + # this should fail as well + self.assertRaises(psycopg2.IntegrityError, self.wait, cur) + # but this should work + cur.execute("insert into table1 values (%s)", (2, )) + self.wait(cur) + # and the cursor should be usable afterwards + cur.execute("insert into table1 values (%s)", (3, )) + self.wait(cur) + cur.execute("select * from table1 order by id") + self.wait(cur) + self.assertEquals(cur.fetchall(), [(1, ), (2, ), (3, )]) + cur.execute("delete from table1") + self.wait(cur) def test_stop_on_first_error(self): - cur = self.conn.cursor() - cur.execute("select 1; select x; select 1/0; select 2") - self.assertRaises(psycopg2.errors.UndefinedColumn, self.wait, cur) + with self.conn.cursor() as cur: + cur.execute("select 1; select x; select 1/0; select 2") + self.assertRaises(psycopg2.errors.UndefinedColumn, self.wait, cur) - cur.execute("select 1") - self.wait(cur) - self.assertEqual(cur.fetchone(), (1,)) + cur.execute("select 1") + self.wait(cur) + self.assertEqual(cur.fetchone(), (1,)) def test_error_two_cursors(self): - cur = self.conn.cursor() - cur2 = self.conn.cursor() - cur.execute("select * from no_such_table") - self.assertRaises(psycopg2.ProgrammingError, self.wait, cur) - cur2.execute("select 1") - self.wait(cur2) - self.assertEquals(cur2.fetchone()[0], 1) + with self.conn.cursor() as cur, self.conn.cursor() as cur2: + cur.execute("select * from no_such_table") + self.assertRaises(psycopg2.ProgrammingError, self.wait, cur) + cur2.execute("select 1") + self.wait(cur2) + self.assertEquals(cur2.fetchone()[0], 1) def test_notices(self): del self.conn.notices[:] - cur = self.conn.cursor() - if self.conn.info.server_version >= 90300: - cur.execute("set client_min_messages=debug1") + with self.conn.cursor() as cur: + if self.conn.info.server_version >= 90300: + cur.execute("set client_min_messages=debug1") + self.wait(cur) + cur.execute("create temp table chatty (id serial primary key);") self.wait(cur) - cur.execute("create temp table chatty (id serial primary key);") - self.wait(cur) - self.assertEqual("CREATE TABLE", cur.statusmessage) - self.assert_(self.conn.notices) + self.assertEqual("CREATE TABLE", cur.statusmessage) + self.assert_(self.conn.notices) def test_async_cursor_gone(self): - cur = self.conn.cursor() - cur.execute("select 42;") + with self.conn.cursor() as cur: + cur.execute("select 42;") del cur gc.collect() self.assertRaises(psycopg2.InterfaceError, self.wait, self.conn) # The connection is still usable - cur = self.conn.cursor() - cur.execute("select 42;") - self.wait(self.conn) - self.assertEqual(cur.fetchone(), (42,)) + with self.conn.cursor() as cur: + cur.execute("select 42;") + self.wait(self.conn) + self.assertEqual(cur.fetchone(), (42,)) def test_async_connection_error_message(self): cnn = psycopg2.connect('dbname=thisdatabasedoesntexist', async_=True) @@ -464,40 +459,40 @@ class AsyncTests(ConnectingTestCase): @skip_before_postgres(8, 2) def test_copy_no_hang(self): - cur = self.conn.cursor() - cur.execute("copy (select 1) to stdout") - self.assertRaises(psycopg2.ProgrammingError, self.wait, self.conn) + with self.conn.cursor() as cur: + cur.execute("copy (select 1) to stdout") + self.assertRaises(psycopg2.ProgrammingError, self.wait, self.conn) @slow @skip_before_postgres(9, 0) def test_non_block_after_notification(self): from select import select - cur = self.conn.cursor() - cur.execute(""" - select 1; - do $$ - begin - raise notice 'hello'; - end - $$ language plpgsql; - select pg_sleep(1); - """) + with self.conn.cursor() as cur: + cur.execute(""" + select 1; + do $$ + begin + raise notice 'hello'; + end + $$ language plpgsql; + select pg_sleep(1); + """) - polls = 0 - while True: - state = self.conn.poll() - if state == psycopg2.extensions.POLL_OK: - break - elif state == psycopg2.extensions.POLL_READ: - select([self.conn], [], [], 0.1) - elif state == psycopg2.extensions.POLL_WRITE: - select([], [self.conn], [], 0.1) - else: - raise Exception("Unexpected result from poll: %r", state) - polls += 1 + polls = 0 + while True: + state = self.conn.poll() + if state == psycopg2.extensions.POLL_OK: + break + elif state == psycopg2.extensions.POLL_READ: + select([self.conn], [], [], 0.1) + elif state == psycopg2.extensions.POLL_WRITE: + select([], [self.conn], [], 0.1) + else: + raise Exception("Unexpected result from poll: %r", state) + polls += 1 - self.assert_(polls >= 8, polls) + self.assert_(polls >= 8, polls) def test_poll_noop(self): self.conn.poll() diff --git a/tests/test_async_keyword.py b/tests/test_async_keyword.py index f8e50afe..dbedec01 100755 --- a/tests/test_async_keyword.py +++ b/tests/test_async_keyword.py @@ -47,16 +47,18 @@ class AsyncTests(ConnectingTestCase): self.wait(self.conn) - curs = self.conn.cursor() - curs.execute(''' - CREATE TEMPORARY TABLE table1 ( - id int PRIMARY KEY - )''') - self.wait(curs) + with self.conn.cursor() as curs: + curs.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') + self.wait(curs) def test_connection_setup(self): cur = self.conn.cursor() sync_cur = self.sync_conn.cursor() + cur.close() + sync_cur.close() del cur, sync_cur self.assert_(self.conn.async) @@ -97,11 +99,11 @@ class CancelTests(ConnectingTestCase): def setUp(self): ConnectingTestCase.setUp(self) - cur = self.conn.cursor() - cur.execute(''' - CREATE TEMPORARY TABLE table1 ( - id int PRIMARY KEY - )''') + with self.conn.cursor() as cur: + cur.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') self.conn.commit() @slow @@ -110,16 +112,16 @@ class CancelTests(ConnectingTestCase): async_conn = psycopg2.connect(dsn, async=True) self.assertRaises(psycopg2.OperationalError, async_conn.cancel) extras.wait_select(async_conn) - cur = async_conn.cursor() - cur.execute("select pg_sleep(10)") - time.sleep(1) - self.assertTrue(async_conn.isexecuting()) - async_conn.cancel() - self.assertRaises(psycopg2.extensions.QueryCanceledError, - extras.wait_select, async_conn) - cur.execute("select 1") - extras.wait_select(async_conn) - self.assertEqual(cur.fetchall(), [(1, )]) + with async_conn.cursor() as cur: + cur.execute("select pg_sleep(10)") + time.sleep(1) + self.assertTrue(async_conn.isexecuting()) + async_conn.cancel() + self.assertRaises(psycopg2.extensions.QueryCanceledError, + extras.wait_select, async_conn) + cur.execute("select 1") + extras.wait_select(async_conn) + self.assertEqual(cur.fetchall(), [(1, )]) async_conn.close() def test_async_connection_cancel(self): @@ -183,41 +185,40 @@ class AsyncReplicationTest(ReplicationTestCase): if conn is None: return - cur = conn.cursor() + with conn.cursor() as cur: + self.create_replication_slot(cur, output_plugin='test_decoding') + self.wait(cur) - self.create_replication_slot(cur, output_plugin='test_decoding') - self.wait(cur) + cur.start_replication(self.slot) + self.wait(cur) - cur.start_replication(self.slot) - self.wait(cur) + self.make_replication_events() - self.make_replication_events() + self.msg_count = 0 - self.msg_count = 0 + def consume(msg): + # just check the methods + "%s: %s" % (cur.io_timestamp, repr(msg)) + "%s: %s" % (cur.feedback_timestamp, repr(msg)) - def consume(msg): - # just check the methods - "%s: %s" % (cur.io_timestamp, repr(msg)) - "%s: %s" % (cur.feedback_timestamp, repr(msg)) + self.msg_count += 1 + if self.msg_count > 3: + cur.send_feedback(reply=True) + raise StopReplication() - self.msg_count += 1 - if self.msg_count > 3: - cur.send_feedback(reply=True) - raise StopReplication() + cur.send_feedback(flush_lsn=msg.data_start) - cur.send_feedback(flush_lsn=msg.data_start) + # cannot be used in asynchronous mode + self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume) - # cannot be used in asynchronous mode - self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume) - - def process_stream(): - while True: - msg = cur.read_message() - if msg: - consume(msg) - else: - select([cur], [], []) - self.assertRaises(StopReplication, process_stream) + def process_stream(): + while True: + msg = cur.read_message() + if msg: + consume(msg) + else: + select([cur], [], []) + self.assertRaises(StopReplication, process_stream) def test_suite(): diff --git a/tests/test_bug_gc.py b/tests/test_bug_gc.py index 3d0b789e..ee82336d 100755 --- a/tests/test_bug_gc.py +++ b/tests/test_bug_gc.py @@ -39,9 +39,9 @@ class StolenReferenceTestCase(ConnectingTestCase): return 42 UUID = psycopg2.extensions.new_type((2950,), "UUID", fish) psycopg2.extensions.register_type(UUID, self.conn) - curs = self.conn.cursor() - curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid") - curs.fetchone() + with self.conn.cursor() as curs: + curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid") + curs.fetchone() def test_suite(): diff --git a/tests/test_cancel.py b/tests/test_cancel.py index 06477edc..b4daba3e 100755 --- a/tests/test_cancel.py +++ b/tests/test_cancel.py @@ -41,11 +41,11 @@ class CancelTests(ConnectingTestCase): def setUp(self): ConnectingTestCase.setUp(self) - cur = self.conn.cursor() - cur.execute(''' - CREATE TEMPORARY TABLE table1 ( - id int PRIMARY KEY - )''') + with self.conn.cursor() as cur: + cur.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') self.conn.commit() def test_empty_cancel(self): @@ -57,25 +57,25 @@ class CancelTests(ConnectingTestCase): errors = [] def neverending(conn): - cur = conn.cursor() - try: - self.assertRaises(psycopg2.extensions.QueryCanceledError, - cur.execute, "select pg_sleep(60)") - # make sure the connection still works - conn.rollback() - cur.execute("select 1") - self.assertEqual(cur.fetchall(), [(1, )]) - except Exception as e: - errors.append(e) - raise + with conn.cursor() as cur: + try: + self.assertRaises(psycopg2.extensions.QueryCanceledError, + cur.execute, "select pg_sleep(60)") + # make sure the connection still works + conn.rollback() + cur.execute("select 1") + self.assertEqual(cur.fetchall(), [(1, )]) + except Exception as e: + errors.append(e) + raise def canceller(conn): - cur = conn.cursor() - try: - conn.cancel() - except Exception as e: - errors.append(e) - raise + with conn.cursor() as cur: + try: + conn.cancel() + except Exception as e: + errors.append(e) + raise del cur thread1 = threading.Thread(target=neverending, args=(self.conn, )) @@ -95,16 +95,16 @@ class CancelTests(ConnectingTestCase): async_conn = psycopg2.connect(dsn, async_=True) self.assertRaises(psycopg2.OperationalError, async_conn.cancel) extras.wait_select(async_conn) - cur = async_conn.cursor() - cur.execute("select pg_sleep(10)") - time.sleep(1) - self.assertTrue(async_conn.isexecuting()) - async_conn.cancel() - self.assertRaises(psycopg2.extensions.QueryCanceledError, - extras.wait_select, async_conn) - cur.execute("select 1") - extras.wait_select(async_conn) - self.assertEqual(cur.fetchall(), [(1, )]) + with async_conn.cursor() as cur: + cur.execute("select pg_sleep(10)") + time.sleep(1) + self.assertTrue(async_conn.isexecuting()) + async_conn.cancel() + self.assertRaises(psycopg2.extensions.QueryCanceledError, + extras.wait_select, async_conn) + cur.execute("select 1") + extras.wait_select(async_conn) + self.assertEqual(cur.fetchall(), [(1, )]) async_conn.close() def test_async_connection_cancel(self): diff --git a/tests/test_connection.py b/tests/test_connection.py index f1668d55..c724a7fa 100755 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -80,9 +80,9 @@ class ConnectionTests(ConnectingTestCase): def test_cleanup_on_badconn_close(self): # ticket #148 conn = self.conn - cur = conn.cursor() - self.assertRaises(psycopg2.OperationalError, - cur.execute, "select pg_terminate_backend(pg_backend_pid())") + with conn.cursor() as cur: + self.assertRaises(psycopg2.OperationalError, + cur.execute, "select pg_terminate_backend(pg_backend_pid())") self.assertEqual(conn.closed, 2) conn.close() @@ -113,87 +113,87 @@ class ConnectionTests(ConnectingTestCase): def test_notices(self): conn = self.conn - cur = conn.cursor() - if self.conn.info.server_version >= 90300: - cur.execute("set client_min_messages=debug1") - cur.execute("create temp table chatty (id serial primary key);") - self.assertEqual("CREATE TABLE", cur.statusmessage) - self.assert_(conn.notices) + with conn.cursor() as cur: + if self.conn.info.server_version >= 90300: + cur.execute("set client_min_messages=debug1") + cur.execute("create temp table chatty (id serial primary key);") + self.assertEqual("CREATE TABLE", cur.statusmessage) + self.assert_(conn.notices) def test_notices_consistent_order(self): conn = self.conn - cur = conn.cursor() - if self.conn.info.server_version >= 90300: - cur.execute("set client_min_messages=debug1") - cur.execute(""" - create temp table table1 (id serial); - create temp table table2 (id serial); - """) - cur.execute(""" - create temp table table3 (id serial); - create temp table table4 (id serial); - """) - self.assertEqual(4, len(conn.notices)) - self.assert_('table1' in conn.notices[0]) - self.assert_('table2' in conn.notices[1]) - self.assert_('table3' in conn.notices[2]) - self.assert_('table4' in conn.notices[3]) + with conn.cursor() as cur: + if self.conn.info.server_version >= 90300: + cur.execute("set client_min_messages=debug1") + cur.execute(""" + create temp table table1 (id serial); + create temp table table2 (id serial); + """) + cur.execute(""" + create temp table table3 (id serial); + create temp table table4 (id serial); + """) + self.assertEqual(4, len(conn.notices)) + self.assert_('table1' in conn.notices[0]) + self.assert_('table2' in conn.notices[1]) + self.assert_('table3' in conn.notices[2]) + self.assert_('table4' in conn.notices[3]) @slow def test_notices_limited(self): conn = self.conn - cur = conn.cursor() - if self.conn.info.server_version >= 90300: - cur.execute("set client_min_messages=debug1") - for i in range(0, 100, 10): - sql = " ".join(["create temp table table%d (id serial);" % j - for j in range(i, i + 10)]) - cur.execute(sql) + with conn.cursor() as cur: + if self.conn.info.server_version >= 90300: + cur.execute("set client_min_messages=debug1") + for i in range(0, 100, 10): + sql = " ".join(["create temp table table%d (id serial);" % j + for j in range(i, i + 10)]) + cur.execute(sql) - self.assertEqual(50, len(conn.notices)) - self.assert_('table99' in conn.notices[-1], conn.notices[-1]) + self.assertEqual(50, len(conn.notices)) + self.assert_('table99' in conn.notices[-1], conn.notices[-1]) @slow def test_notices_deque(self): conn = self.conn self.conn.notices = deque() - cur = conn.cursor() - if self.conn.info.server_version >= 90300: - cur.execute("set client_min_messages=debug1") + with conn.cursor() as cur: + if self.conn.info.server_version >= 90300: + cur.execute("set client_min_messages=debug1") - cur.execute(""" - create temp table table1 (id serial); - create temp table table2 (id serial); - """) - cur.execute(""" - create temp table table3 (id serial); - create temp table table4 (id serial);""") - self.assertEqual(len(conn.notices), 4) - self.assert_('table1' in conn.notices.popleft()) - self.assert_('table2' in conn.notices.popleft()) - self.assert_('table3' in conn.notices.popleft()) - self.assert_('table4' in conn.notices.popleft()) - self.assertEqual(len(conn.notices), 0) + cur.execute(""" + create temp table table1 (id serial); + create temp table table2 (id serial); + """) + cur.execute(""" + create temp table table3 (id serial); + create temp table table4 (id serial);""") + self.assertEqual(len(conn.notices), 4) + self.assert_('table1' in conn.notices.popleft()) + self.assert_('table2' in conn.notices.popleft()) + self.assert_('table3' in conn.notices.popleft()) + self.assert_('table4' in conn.notices.popleft()) + self.assertEqual(len(conn.notices), 0) - # not limited, but no error - for i in range(0, 100, 10): - sql = " ".join(["create temp table table2_%d (id serial);" % j - for j in range(i, i + 10)]) - cur.execute(sql) + # not limited, but no error + for i in range(0, 100, 10): + sql = " ".join(["create temp table table2_%d (id serial);" % j + for j in range(i, i + 10)]) + cur.execute(sql) - self.assertEqual(len([n for n in conn.notices if 'CREATE TABLE' in n]), - 100) + self.assertEqual(len([n for n in conn.notices if 'CREATE TABLE' in n]), + 100) def test_notices_noappend(self): conn = self.conn self.conn.notices = None # will make an error swallowes ok - cur = conn.cursor() - if self.conn.info.server_version >= 90300: - cur.execute("set client_min_messages=debug1") + with conn.cursor() as cur: + if self.conn.info.server_version >= 90300: + cur.execute("set client_min_messages=debug1") - cur.execute("create temp table table1 (id serial);") + cur.execute("create temp table table1 (id serial);") - self.assertEqual(self.conn.notices, None) + self.assertEqual(self.conn.notices, None) def test_server_version(self): self.assert_(self.conn.server_version) @@ -233,10 +233,10 @@ class ConnectionTests(ConnectingTestCase): def test_encoding_name(self): self.conn.set_client_encoding("EUC_JP") # conn.encoding is 'EUCJP' now. - cur = self.conn.cursor() - ext.register_type(ext.UNICODE, cur) - cur.execute("select 'foo'::text;") - self.assertEqual(cur.fetchone()[0], u'foo') + with self.conn.cursor() as cur: + ext.register_type(ext.UNICODE, cur) + cur.execute("select 'foo'::text;") + self.assertEqual(cur.fetchone()[0], u'foo') def test_connect_nonnormal_envvar(self): # We must perform encoding normalization at connection time @@ -281,14 +281,14 @@ class ConnectionTests(ConnectingTestCase): while conn.notices: notices.append((2, conn.notices.pop())) - cur = conn.cursor() - t1 = threading.Thread(target=committer) - t1.start() - for i in range(1000): - cur.execute("select %s;", (i,)) - conn.commit() - while conn.notices: - notices.append((1, conn.notices.pop())) + with conn.cursor() as cur: + t1 = threading.Thread(target=committer) + t1.start() + for i in range(1000): + cur.execute("select %s;", (i,)) + conn.commit() + while conn.notices: + notices.append((1, conn.notices.pop())) # Stop the committer thread stop.append(True) @@ -297,37 +297,37 @@ class ConnectionTests(ConnectingTestCase): def test_connect_cursor_factory(self): conn = self.connect(cursor_factory=psycopg2.extras.DictCursor) - cur = conn.cursor() - cur.execute("select 1 as a") - self.assertEqual(cur.fetchone()['a'], 1) + with conn.cursor() as cur: + cur.execute("select 1 as a") + self.assertEqual(cur.fetchone()['a'], 1) def test_cursor_factory(self): self.assertEqual(self.conn.cursor_factory, None) - cur = self.conn.cursor() - cur.execute("select 1 as a") - self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone()) + with self.conn.cursor() as cur: + cur.execute("select 1 as a") + self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone()) self.conn.cursor_factory = psycopg2.extras.DictCursor self.assertEqual(self.conn.cursor_factory, psycopg2.extras.DictCursor) - cur = self.conn.cursor() - cur.execute("select 1 as a") - self.assertEqual(cur.fetchone()['a'], 1) + with self.conn.cursor() as cur: + cur.execute("select 1 as a") + self.assertEqual(cur.fetchone()['a'], 1) self.conn.cursor_factory = None self.assertEqual(self.conn.cursor_factory, None) - cur = self.conn.cursor() - cur.execute("select 1 as a") - self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone()) + with self.conn.cursor() as cur: + cur.execute("select 1 as a") + self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone()) def test_cursor_factory_none(self): # issue #210 conn = self.connect() - cur = conn.cursor(cursor_factory=None) - self.assertEqual(type(cur), ext.cursor) + with conn.cursor(cursor_factory=None) as cur: + self.assertEqual(type(cur), ext.cursor) conn = self.connect(cursor_factory=psycopg2.extras.DictCursor) - cur = conn.cursor(cursor_factory=None) - self.assertEqual(type(cur), psycopg2.extras.DictCursor) + with conn.cursor(cursor_factory=None) as cur: + self.assertEqual(type(cur), psycopg2.extras.DictCursor) def test_failed_init_status(self): class SubConnection(ext.connection): @@ -583,179 +583,175 @@ class IsolationLevelsTestCase(ConnectingTestCase): def test_set_isolation_level(self): conn = self.connect() - curs = conn.cursor() + with conn.cursor() as curs: + levels = [ + ('read uncommitted', + ext.ISOLATION_LEVEL_READ_UNCOMMITTED), + ('read committed', ext.ISOLATION_LEVEL_READ_COMMITTED), + ('repeatable read', ext.ISOLATION_LEVEL_REPEATABLE_READ), + ('serializable', ext.ISOLATION_LEVEL_SERIALIZABLE), + ] + for name, level in levels: + conn.set_isolation_level(level) - levels = [ - ('read uncommitted', - ext.ISOLATION_LEVEL_READ_UNCOMMITTED), - ('read committed', ext.ISOLATION_LEVEL_READ_COMMITTED), - ('repeatable read', ext.ISOLATION_LEVEL_REPEATABLE_READ), - ('serializable', ext.ISOLATION_LEVEL_SERIALIZABLE), - ] - for name, level in levels: - conn.set_isolation_level(level) + # the only values available on prehistoric PG versions + if conn.info.server_version < 80000: + if level in ( + ext.ISOLATION_LEVEL_READ_UNCOMMITTED, + ext.ISOLATION_LEVEL_REPEATABLE_READ): + name, level = levels[levels.index((name, level)) + 1] - # the only values available on prehistoric PG versions - if conn.info.server_version < 80000: - if level in ( - ext.ISOLATION_LEVEL_READ_UNCOMMITTED, - ext.ISOLATION_LEVEL_REPEATABLE_READ): - name, level = levels[levels.index((name, level)) + 1] + self.assertEqual(conn.isolation_level, level) - self.assertEqual(conn.isolation_level, level) + curs.execute('show transaction_isolation;') + got_name = curs.fetchone()[0] - curs.execute('show transaction_isolation;') - got_name = curs.fetchone()[0] + self.assertEqual(name, got_name) + conn.commit() - self.assertEqual(name, got_name) - conn.commit() - - self.assertRaises(ValueError, conn.set_isolation_level, -1) - self.assertRaises(ValueError, conn.set_isolation_level, 5) + self.assertRaises(ValueError, conn.set_isolation_level, -1) + self.assertRaises(ValueError, conn.set_isolation_level, 5) def test_set_isolation_level_autocommit(self): conn = self.connect() - curs = conn.cursor() + with conn.cursor() as curs: + conn.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT) + self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT) + self.assert_(conn.autocommit) - conn.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT) - self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT) - self.assert_(conn.autocommit) + conn.isolation_level = 'serializable' + self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) + self.assert_(conn.autocommit) - conn.isolation_level = 'serializable' - self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) - self.assert_(conn.autocommit) - - curs.execute('show transaction_isolation;') - self.assertEqual(curs.fetchone()[0], 'serializable') + curs.execute('show transaction_isolation;') + self.assertEqual(curs.fetchone()[0], 'serializable') def test_set_isolation_level_default(self): conn = self.connect() - curs = conn.cursor() + with conn.cursor() as curs: + conn.autocommit = True + curs.execute("set default_transaction_isolation to 'read committed'") - conn.autocommit = True - curs.execute("set default_transaction_isolation to 'read committed'") + conn.autocommit = False + conn.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE) + self.assertEqual(conn.isolation_level, + ext.ISOLATION_LEVEL_SERIALIZABLE) + curs.execute("show transaction_isolation") + self.assertEqual(curs.fetchone()[0], "serializable") - conn.autocommit = False - conn.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE) - self.assertEqual(conn.isolation_level, - ext.ISOLATION_LEVEL_SERIALIZABLE) - curs.execute("show transaction_isolation") - self.assertEqual(curs.fetchone()[0], "serializable") - - conn.rollback() - conn.set_isolation_level(ext.ISOLATION_LEVEL_DEFAULT) - curs.execute("show transaction_isolation") - self.assertEqual(curs.fetchone()[0], "read committed") + conn.rollback() + conn.set_isolation_level(ext.ISOLATION_LEVEL_DEFAULT) + curs.execute("show transaction_isolation") + self.assertEqual(curs.fetchone()[0], "read committed") def test_set_isolation_level_abort(self): conn = self.connect() - cur = conn.cursor() + with conn.cursor() as cur: + self.assertEqual(ext.TRANSACTION_STATUS_IDLE, + conn.info.transaction_status) + cur.execute("insert into isolevel values (10);") + self.assertEqual(ext.TRANSACTION_STATUS_INTRANS, + conn.info.transaction_status) - self.assertEqual(ext.TRANSACTION_STATUS_IDLE, - conn.info.transaction_status) - cur.execute("insert into isolevel values (10);") - self.assertEqual(ext.TRANSACTION_STATUS_INTRANS, - conn.info.transaction_status) + conn.set_isolation_level( + psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE) + self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, + conn.info.transaction_status) + cur.execute("select count(*) from isolevel;") + self.assertEqual(0, cur.fetchone()[0]) - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE) - self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, - conn.info.transaction_status) - cur.execute("select count(*) from isolevel;") - self.assertEqual(0, cur.fetchone()[0]) + cur.execute("insert into isolevel values (10);") + self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_INTRANS, + conn.info.transaction_status) + conn.set_isolation_level( + psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) + self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, + conn.info.transaction_status) + cur.execute("select count(*) from isolevel;") + self.assertEqual(0, cur.fetchone()[0]) - cur.execute("insert into isolevel values (10);") - self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_INTRANS, - conn.info.transaction_status) - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) - self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, - conn.info.transaction_status) - cur.execute("select count(*) from isolevel;") - self.assertEqual(0, cur.fetchone()[0]) - - cur.execute("insert into isolevel values (10);") - self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, - conn.info.transaction_status) - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED) - self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, - conn.info.transaction_status) - cur.execute("select count(*) from isolevel;") - self.assertEqual(1, cur.fetchone()[0]) - self.assertEqual(conn.isolation_level, - psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED) + cur.execute("insert into isolevel values (10);") + self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, + conn.info.transaction_status) + conn.set_isolation_level( + psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED) + self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, + conn.info.transaction_status) + cur.execute("select count(*) from isolevel;") + self.assertEqual(1, cur.fetchone()[0]) + self.assertEqual(conn.isolation_level, + psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED) def test_isolation_level_autocommit(self): cnn1 = self.connect() cnn2 = self.connect() cnn2.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT) - cur1 = cnn1.cursor() - cur1.execute("select count(*) from isolevel;") - self.assertEqual(0, cur1.fetchone()[0]) - cnn1.commit() + with cnn1.cursor() as cur1: + cur1.execute("select count(*) from isolevel;") + self.assertEqual(0, cur1.fetchone()[0]) + cnn1.commit() - cur2 = cnn2.cursor() - cur2.execute("insert into isolevel values (10);") + with cnn2.cursor() as cur2: + cur2.execute("insert into isolevel values (10);") - cur1.execute("select count(*) from isolevel;") - self.assertEqual(1, cur1.fetchone()[0]) + cur1.execute("select count(*) from isolevel;") + self.assertEqual(1, cur1.fetchone()[0]) def test_isolation_level_read_committed(self): cnn1 = self.connect() cnn2 = self.connect() cnn2.set_isolation_level(ext.ISOLATION_LEVEL_READ_COMMITTED) - cur1 = cnn1.cursor() - cur1.execute("select count(*) from isolevel;") - self.assertEqual(0, cur1.fetchone()[0]) - cnn1.commit() + with cnn1.cursor() as cur1: + cur1.execute("select count(*) from isolevel;") + self.assertEqual(0, cur1.fetchone()[0]) + cnn1.commit() - cur2 = cnn2.cursor() - cur2.execute("insert into isolevel values (10);") - cur1.execute("insert into isolevel values (20);") + with cnn2.cursor() as cur2: + cur2.execute("insert into isolevel values (10);") + cur1.execute("insert into isolevel values (20);") - cur2.execute("select count(*) from isolevel;") - self.assertEqual(1, cur2.fetchone()[0]) - cnn1.commit() - cur2.execute("select count(*) from isolevel;") - self.assertEqual(2, cur2.fetchone()[0]) + cur2.execute("select count(*) from isolevel;") + self.assertEqual(1, cur2.fetchone()[0]) + cnn1.commit() + cur2.execute("select count(*) from isolevel;") + self.assertEqual(2, cur2.fetchone()[0]) - cur1.execute("select count(*) from isolevel;") - self.assertEqual(1, cur1.fetchone()[0]) - cnn2.commit() - cur1.execute("select count(*) from isolevel;") - self.assertEqual(2, cur1.fetchone()[0]) + cur1.execute("select count(*) from isolevel;") + self.assertEqual(1, cur1.fetchone()[0]) + cnn2.commit() + cur1.execute("select count(*) from isolevel;") + self.assertEqual(2, cur1.fetchone()[0]) def test_isolation_level_serializable(self): cnn1 = self.connect() cnn2 = self.connect() cnn2.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE) - cur1 = cnn1.cursor() - cur1.execute("select count(*) from isolevel;") - self.assertEqual(0, cur1.fetchone()[0]) - cnn1.commit() + with cnn1.cursor() as cur1: + cur1.execute("select count(*) from isolevel;") + self.assertEqual(0, cur1.fetchone()[0]) + cnn1.commit() - cur2 = cnn2.cursor() - cur2.execute("insert into isolevel values (10);") - cur1.execute("insert into isolevel values (20);") + with cnn2.cursor() as cur2: + cur2.execute("insert into isolevel values (10);") + cur1.execute("insert into isolevel values (20);") - cur2.execute("select count(*) from isolevel;") - self.assertEqual(1, cur2.fetchone()[0]) - cnn1.commit() - cur2.execute("select count(*) from isolevel;") - self.assertEqual(1, cur2.fetchone()[0]) + cur2.execute("select count(*) from isolevel;") + self.assertEqual(1, cur2.fetchone()[0]) + cnn1.commit() + cur2.execute("select count(*) from isolevel;") + self.assertEqual(1, cur2.fetchone()[0]) - cur1.execute("select count(*) from isolevel;") - self.assertEqual(1, cur1.fetchone()[0]) - cnn2.commit() - cur1.execute("select count(*) from isolevel;") - self.assertEqual(2, cur1.fetchone()[0]) + cur1.execute("select count(*) from isolevel;") + self.assertEqual(1, cur1.fetchone()[0]) + cnn2.commit() + cur1.execute("select count(*) from isolevel;") + self.assertEqual(2, cur1.fetchone()[0]) - cur2.execute("select count(*) from isolevel;") - self.assertEqual(2, cur2.fetchone()[0]) + cur2.execute("select count(*) from isolevel;") + self.assertEqual(2, cur2.fetchone()[0]) def test_isolation_level_closed(self): cnn = self.connect() @@ -766,99 +762,99 @@ class IsolationLevelsTestCase(ConnectingTestCase): cnn.set_isolation_level, 1) def test_setattr_isolation_level_int(self): - cur = self.conn.cursor() - self.conn.isolation_level = ext.ISOLATION_LEVEL_SERIALIZABLE - self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) + with self.conn.cursor() as cur: + self.conn.isolation_level = ext.ISOLATION_LEVEL_SERIALIZABLE + self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() - - self.conn.isolation_level = ext.ISOLATION_LEVEL_REPEATABLE_READ - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_REPEATABLE_READ) - self.assertEqual(cur.fetchone()[0], 'repeatable read') - else: - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_SERIALIZABLE) + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() + self.conn.rollback() - self.conn.isolation_level = ext.ISOLATION_LEVEL_READ_COMMITTED - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_READ_COMMITTED) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.isolation_level = ext.ISOLATION_LEVEL_REPEATABLE_READ + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_REPEATABLE_READ) + self.assertEqual(cur.fetchone()[0], 'repeatable read') + else: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_SERIALIZABLE) + self.assertEqual(cur.fetchone()[0], 'serializable') + self.conn.rollback() - self.conn.isolation_level = ext.ISOLATION_LEVEL_READ_UNCOMMITTED - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_READ_UNCOMMITTED) - self.assertEqual(cur.fetchone()[0], 'read uncommitted') - else: + self.conn.isolation_level = ext.ISOLATION_LEVEL_READ_COMMITTED self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_READ_COMMITTED) + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.rollback() - self.assertEqual(ext.ISOLATION_LEVEL_DEFAULT, None) - self.conn.isolation_level = ext.ISOLATION_LEVEL_DEFAULT - self.assertEqual(self.conn.isolation_level, None) - cur.execute("SHOW transaction_isolation;") - isol = cur.fetchone()[0] - cur.execute("SHOW default_transaction_isolation;") - self.assertEqual(cur.fetchone()[0], isol) + self.conn.isolation_level = ext.ISOLATION_LEVEL_READ_UNCOMMITTED + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_READ_UNCOMMITTED) + self.assertEqual(cur.fetchone()[0], 'read uncommitted') + else: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_READ_COMMITTED) + self.assertEqual(cur.fetchone()[0], 'read committed') + self.conn.rollback() + + self.assertEqual(ext.ISOLATION_LEVEL_DEFAULT, None) + self.conn.isolation_level = ext.ISOLATION_LEVEL_DEFAULT + self.assertEqual(self.conn.isolation_level, None) + cur.execute("SHOW transaction_isolation;") + isol = cur.fetchone()[0] + cur.execute("SHOW default_transaction_isolation;") + self.assertEqual(cur.fetchone()[0], isol) def test_setattr_isolation_level_str(self): - cur = self.conn.cursor() - self.conn.isolation_level = "serializable" - self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) + with self.conn.cursor() as cur: + self.conn.isolation_level = "serializable" + self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() - - self.conn.isolation_level = "repeatable read" - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_REPEATABLE_READ) - self.assertEqual(cur.fetchone()[0], 'repeatable read') - else: - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_SERIALIZABLE) + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() + self.conn.rollback() - self.conn.isolation_level = "read committed" - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_READ_COMMITTED) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.isolation_level = "repeatable read" + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_REPEATABLE_READ) + self.assertEqual(cur.fetchone()[0], 'repeatable read') + else: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_SERIALIZABLE) + self.assertEqual(cur.fetchone()[0], 'serializable') + self.conn.rollback() - self.conn.isolation_level = "read uncommitted" - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_READ_UNCOMMITTED) - self.assertEqual(cur.fetchone()[0], 'read uncommitted') - else: + self.conn.isolation_level = "read committed" self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_READ_COMMITTED) + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.rollback() - self.conn.isolation_level = "default" - self.assertEqual(self.conn.isolation_level, None) - cur.execute("SHOW transaction_isolation;") - isol = cur.fetchone()[0] - cur.execute("SHOW default_transaction_isolation;") - self.assertEqual(cur.fetchone()[0], isol) + self.conn.isolation_level = "read uncommitted" + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_READ_UNCOMMITTED) + self.assertEqual(cur.fetchone()[0], 'read uncommitted') + else: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_READ_COMMITTED) + self.assertEqual(cur.fetchone()[0], 'read committed') + self.conn.rollback() + + self.conn.isolation_level = "default" + self.assertEqual(self.conn.isolation_level, None) + cur.execute("SHOW transaction_isolation;") + isol = cur.fetchone()[0] + cur.execute("SHOW default_transaction_isolation;") + self.assertEqual(cur.fetchone()[0], isol) def test_setattr_isolation_level_invalid(self): self.assertRaises(ValueError, setattr, self.conn, 'isolation_level', 0) @@ -946,20 +942,20 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_begin(xid) self.assertEqual(cnn.status, ext.STATUS_BEGIN) - cur = cnn.cursor() - cur.execute("insert into test_tpc values ('test_tpc_commit');") - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + with cnn.cursor() as cur: + cur.execute("insert into test_tpc values ('test_tpc_commit');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_prepare() - self.assertEqual(cnn.status, ext.STATUS_PREPARED) - self.assertEqual(1, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + cnn.tpc_prepare() + self.assertEqual(cnn.status, ext.STATUS_PREPARED) + self.assertEqual(1, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_commit() - self.assertEqual(cnn.status, ext.STATUS_READY) - self.assertEqual(0, self.count_xacts()) - self.assertEqual(1, self.count_test_records()) + cnn.tpc_commit() + self.assertEqual(cnn.status, ext.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(1, self.count_test_records()) def test_tpc_commit_one_phase(self): cnn = self.connect() @@ -969,15 +965,15 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_begin(xid) self.assertEqual(cnn.status, ext.STATUS_BEGIN) - cur = cnn.cursor() - cur.execute("insert into test_tpc values ('test_tpc_commit_1p');") - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + with cnn.cursor() as cur: + cur.execute("insert into test_tpc values ('test_tpc_commit_1p');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_commit() - self.assertEqual(cnn.status, ext.STATUS_READY) - self.assertEqual(0, self.count_xacts()) - self.assertEqual(1, self.count_test_records()) + cnn.tpc_commit() + self.assertEqual(cnn.status, ext.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(1, self.count_test_records()) def test_tpc_commit_recovered(self): cnn = self.connect() @@ -1013,20 +1009,20 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_begin(xid) self.assertEqual(cnn.status, ext.STATUS_BEGIN) - cur = cnn.cursor() - cur.execute("insert into test_tpc values ('test_tpc_rollback');") - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + with cnn.cursor() as cur: + cur.execute("insert into test_tpc values ('test_tpc_rollback');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_prepare() - self.assertEqual(cnn.status, ext.STATUS_PREPARED) - self.assertEqual(1, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + cnn.tpc_prepare() + self.assertEqual(cnn.status, ext.STATUS_PREPARED) + self.assertEqual(1, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_rollback() - self.assertEqual(cnn.status, ext.STATUS_READY) - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + cnn.tpc_rollback() + self.assertEqual(cnn.status, ext.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) def test_tpc_rollback_one_phase(self): cnn = self.connect() @@ -1036,15 +1032,15 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_begin(xid) self.assertEqual(cnn.status, ext.STATUS_BEGIN) - cur = cnn.cursor() - cur.execute("insert into test_tpc values ('test_tpc_rollback_1p');") - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + with cnn.cursor() as cur: + cur.execute("insert into test_tpc values ('test_tpc_rollback_1p');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_rollback() - self.assertEqual(cnn.status, ext.STATUS_READY) - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + cnn.tpc_rollback() + self.assertEqual(cnn.status, ext.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) def test_tpc_rollback_recovered(self): cnn = self.connect() @@ -1054,23 +1050,23 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_begin(xid) self.assertEqual(cnn.status, ext.STATUS_BEGIN) - cur = cnn.cursor() - cur.execute("insert into test_tpc values ('test_tpc_commit_rec');") - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + with cnn.cursor() as cur: + cur.execute("insert into test_tpc values ('test_tpc_commit_rec');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_prepare() - cnn.close() - self.assertEqual(1, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + cnn.tpc_prepare() + cnn.close() + self.assertEqual(1, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn = self.connect() - xid = cnn.xid(1, "gtrid", "bqual") - cnn.tpc_rollback(xid) + cnn = self.connect() + xid = cnn.xid(1, "gtrid", "bqual") + cnn.tpc_rollback(xid) - self.assertEqual(cnn.status, ext.STATUS_READY) - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + self.assertEqual(cnn.status, ext.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) def test_status_after_recover(self): cnn = self.connect() @@ -1078,28 +1074,28 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_recover() self.assertEqual(ext.STATUS_READY, cnn.status) - cur = cnn.cursor() - cur.execute("select 1") - self.assertEqual(ext.STATUS_BEGIN, cnn.status) - cnn.tpc_recover() - self.assertEqual(ext.STATUS_BEGIN, cnn.status) + with cnn.cursor() as cur: + cur.execute("select 1") + self.assertEqual(ext.STATUS_BEGIN, cnn.status) + cnn.tpc_recover() + self.assertEqual(ext.STATUS_BEGIN, cnn.status) def test_recovered_xids(self): # insert a few test xns cnn = self.connect() cnn.set_isolation_level(0) - cur = cnn.cursor() - cur.execute("begin; prepare transaction '1-foo';") - cur.execute("begin; prepare transaction '2-bar';") + with cnn.cursor() as cur: + cur.execute("begin; prepare transaction '1-foo';") + cur.execute("begin; prepare transaction '2-bar';") - # read the values to return - cur.execute(""" - select gid, prepared, owner, database - from pg_prepared_xacts - where database = %s;""", - (dbname,)) - okvals = cur.fetchall() - okvals.sort() + # read the values to return + cur.execute(""" + select gid, prepared, owner, database + from pg_prepared_xacts + where database = %s;""", + (dbname,)) + okvals = cur.fetchall() + okvals.sort() cnn = self.connect() xids = cnn.tpc_recover() @@ -1121,10 +1117,10 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_prepare() cnn = self.connect() - cur = cnn.cursor() - cur.execute("select gid from pg_prepared_xacts where database = %s;", - (dbname,)) - self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0]) + with cnn.cursor() as cur: + cur.execute("select gid from pg_prepared_xacts where database = %s;", + (dbname,)) + self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0]) @slow def test_xid_roundtrip(self): @@ -1248,71 +1244,71 @@ class TransactionControlTests(ConnectingTestCase): ext.ISOLATION_LEVEL_SERIALIZABLE) def test_not_in_transaction(self): - cur = self.conn.cursor() - cur.execute("select 1") - self.assertRaises(psycopg2.ProgrammingError, - self.conn.set_session, - ext.ISOLATION_LEVEL_SERIALIZABLE) + with self.conn.cursor() as cur: + cur.execute("select 1") + self.assertRaises(psycopg2.ProgrammingError, + self.conn.set_session, + ext.ISOLATION_LEVEL_SERIALIZABLE) def test_set_isolation_level(self): - cur = self.conn.cursor() - self.conn.set_session( - ext.ISOLATION_LEVEL_SERIALIZABLE) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() - - self.conn.set_session( - ext.ISOLATION_LEVEL_REPEATABLE_READ) - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(cur.fetchone()[0], 'repeatable read') - else: + with self.conn.cursor() as cur: + self.conn.set_session( + ext.ISOLATION_LEVEL_SERIALIZABLE) + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() + self.conn.rollback() - self.conn.set_session( - isolation_level=ext.ISOLATION_LEVEL_READ_COMMITTED) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.set_session( + ext.ISOLATION_LEVEL_REPEATABLE_READ) + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(cur.fetchone()[0], 'repeatable read') + else: + self.assertEqual(cur.fetchone()[0], 'serializable') + self.conn.rollback() - self.conn.set_session( - isolation_level=ext.ISOLATION_LEVEL_READ_UNCOMMITTED) - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(cur.fetchone()[0], 'read uncommitted') - else: + self.conn.set_session( + isolation_level=ext.ISOLATION_LEVEL_READ_COMMITTED) + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.rollback() + + self.conn.set_session( + isolation_level=ext.ISOLATION_LEVEL_READ_UNCOMMITTED) + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(cur.fetchone()[0], 'read uncommitted') + else: + self.assertEqual(cur.fetchone()[0], 'read committed') + self.conn.rollback() def test_set_isolation_level_str(self): - cur = self.conn.cursor() - self.conn.set_session("serializable") - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() - - self.conn.set_session("repeatable read") - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(cur.fetchone()[0], 'repeatable read') - else: + with self.conn.cursor() as cur: + self.conn.set_session("serializable") + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() + self.conn.rollback() - self.conn.set_session("read committed") - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.set_session("repeatable read") + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(cur.fetchone()[0], 'repeatable read') + else: + self.assertEqual(cur.fetchone()[0], 'serializable') + self.conn.rollback() - self.conn.set_session("read uncommitted") - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(cur.fetchone()[0], 'read uncommitted') - else: + self.conn.set_session("read committed") + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.rollback() + + self.conn.set_session("read uncommitted") + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(cur.fetchone()[0], 'read uncommitted') + else: + self.assertEqual(cur.fetchone()[0], 'read committed') + self.conn.rollback() def test_bad_isolation_level(self): self.assertRaises(ValueError, self.conn.set_session, 0) @@ -1322,87 +1318,87 @@ class TransactionControlTests(ConnectingTestCase): def test_set_read_only(self): self.assert_(self.conn.readonly is None) - cur = self.conn.cursor() - self.conn.set_session(readonly=True) - self.assert_(self.conn.readonly is True) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') - self.conn.rollback() - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') - self.conn.rollback() + with self.conn.cursor() as cur: + self.conn.set_session(readonly=True) + self.assert_(self.conn.readonly is True) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') + self.conn.rollback() + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') + self.conn.rollback() - self.conn.set_session(readonly=False) - self.assert_(self.conn.readonly is False) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'off') - self.conn.rollback() + self.conn.set_session(readonly=False) + self.assert_(self.conn.readonly is False) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'off') + self.conn.rollback() def test_setattr_read_only(self): - cur = self.conn.cursor() - self.conn.readonly = True - self.assert_(self.conn.readonly is True) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') - self.assertRaises(self.conn.ProgrammingError, - setattr, self.conn, 'readonly', False) - self.assert_(self.conn.readonly is True) - self.conn.rollback() - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') - self.conn.rollback() + with self.conn.cursor() as cur: + self.conn.readonly = True + self.assert_(self.conn.readonly is True) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') + self.assertRaises(self.conn.ProgrammingError, + setattr, self.conn, 'readonly', False) + self.assert_(self.conn.readonly is True) + self.conn.rollback() + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') + self.conn.rollback() - cur = self.conn.cursor() - self.conn.readonly = None - self.assert_(self.conn.readonly is None) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'off') # assume defined by server - self.conn.rollback() + with self.conn.cursor() as cur: + self.conn.readonly = None + self.assert_(self.conn.readonly is None) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'off') # assume defined by server + self.conn.rollback() - self.conn.readonly = False - self.assert_(self.conn.readonly is False) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'off') - self.conn.rollback() + self.conn.readonly = False + self.assert_(self.conn.readonly is False) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'off') + self.conn.rollback() def test_set_default(self): - cur = self.conn.cursor() - cur.execute("SHOW transaction_isolation;") - isolevel = cur.fetchone()[0] - cur.execute("SHOW transaction_read_only;") - readonly = cur.fetchone()[0] - self.conn.rollback() + with self.conn.cursor() as cur: + cur.execute("SHOW transaction_isolation;") + isolevel = cur.fetchone()[0] + cur.execute("SHOW transaction_read_only;") + readonly = cur.fetchone()[0] + self.conn.rollback() - self.conn.set_session(isolation_level='serializable', readonly=True) - self.conn.set_session(isolation_level='default', readonly='default') + self.conn.set_session(isolation_level='serializable', readonly=True) + self.conn.set_session(isolation_level='default', readonly='default') - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], isolevel) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], readonly) + cur.execute("SHOW transaction_isolation;") + self.assertEqual(cur.fetchone()[0], isolevel) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], readonly) @skip_before_postgres(9, 1) def test_set_deferrable(self): self.assert_(self.conn.deferrable is None) - cur = self.conn.cursor() - self.conn.set_session(readonly=True, deferrable=True) - self.assert_(self.conn.deferrable is True) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'on') - self.conn.rollback() - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'on') - self.conn.rollback() + with self.conn.cursor() as cur: + self.conn.set_session(readonly=True, deferrable=True) + self.assert_(self.conn.deferrable is True) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'on') + self.conn.rollback() + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'on') + self.conn.rollback() - self.conn.set_session(deferrable=False) - self.assert_(self.conn.deferrable is False) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'off') - self.conn.rollback() + self.conn.set_session(deferrable=False) + self.assert_(self.conn.deferrable is False) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'off') + self.conn.rollback() @skip_after_postgres(9, 1) def test_set_deferrable_error(self): @@ -1413,49 +1409,49 @@ class TransactionControlTests(ConnectingTestCase): @skip_before_postgres(9, 1) def test_setattr_deferrable(self): - cur = self.conn.cursor() - self.conn.deferrable = True - self.assert_(self.conn.deferrable is True) - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'on') - self.assertRaises(self.conn.ProgrammingError, - setattr, self.conn, 'deferrable', False) - self.assert_(self.conn.deferrable is True) - self.conn.rollback() - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'on') - self.conn.rollback() + with self.conn.cursor() as cur: + self.conn.deferrable = True + self.assert_(self.conn.deferrable is True) + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'on') + self.assertRaises(self.conn.ProgrammingError, + setattr, self.conn, 'deferrable', False) + self.assert_(self.conn.deferrable is True) + self.conn.rollback() + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'on') + self.conn.rollback() - cur = self.conn.cursor() - self.conn.deferrable = None - self.assert_(self.conn.deferrable is None) - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'off') # assume defined by server - self.conn.rollback() + with self.conn.cursor() as cur: + self.conn.deferrable = None + self.assert_(self.conn.deferrable is None) + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'off') # assume defined by server + self.conn.rollback() - self.conn.deferrable = False - self.assert_(self.conn.deferrable is False) - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'off') - self.conn.rollback() + self.conn.deferrable = False + self.assert_(self.conn.deferrable is False) + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'off') + self.conn.rollback() def test_mixing_session_attribs(self): - cur = self.conn.cursor() - self.conn.autocommit = True - self.conn.readonly = True + with self.conn.cursor() as cur: + self.conn.autocommit = True + self.conn.readonly = True - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') - cur.execute("SHOW default_transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') + cur.execute("SHOW default_transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') - self.conn.autocommit = False - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') + self.conn.autocommit = False + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') - cur.execute("SHOW default_transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'off') + cur.execute("SHOW default_transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'off') def test_idempotence_check(self): self.conn.autocommit = False @@ -1463,9 +1459,9 @@ class TransactionControlTests(ConnectingTestCase): self.conn.autocommit = True self.conn.readonly = True - cur = self.conn.cursor() - cur.execute("SHOW transaction_read_only") - self.assertEqual(cur.fetchone()[0], 'on') + with self.conn.cursor() as cur: + cur.execute("SHOW transaction_read_only") + self.assertEqual(cur.fetchone()[0], 'on') class TestEncryptPassword(ConnectingTestCase): @@ -1486,26 +1482,26 @@ class TestEncryptPassword(ConnectingTestCase): @skip_before_libpq(10) @skip_before_postgres(10) def test_encrypt_server(self): - cur = self.conn.cursor() - cur.execute("SHOW password_encryption;") - server_encryption_algorithm = cur.fetchone()[0] + with self.conn.cursor() as cur: + cur.execute("SHOW password_encryption;") + server_encryption_algorithm = cur.fetchone()[0] - enc_password = ext.encrypt_password( - 'psycopg2', 'ashesh', self.conn) + enc_password = ext.encrypt_password( + 'psycopg2', 'ashesh', self.conn) + + if server_encryption_algorithm == 'md5': + self.assertEqual( + enc_password, 'md594839d658c28a357126f105b9cb14cfc') + elif server_encryption_algorithm == 'scram-sha-256': + self.assertEqual(enc_password[:14], 'SCRAM-SHA-256$') - if server_encryption_algorithm == 'md5': self.assertEqual( - enc_password, 'md594839d658c28a357126f105b9cb14cfc') - elif server_encryption_algorithm == 'scram-sha-256': - self.assertEqual(enc_password[:14], 'SCRAM-SHA-256$') + ext.encrypt_password( + 'psycopg2', 'ashesh', self.conn, 'scram-sha-256' + )[:14], 'SCRAM-SHA-256$') - self.assertEqual( - ext.encrypt_password( - 'psycopg2', 'ashesh', self.conn, 'scram-sha-256' - )[:14], 'SCRAM-SHA-256$') - - self.assertRaises(psycopg2.ProgrammingError, - ext.encrypt_password, 'psycopg2', 'ashesh', self.conn, 'abc') + self.assertRaises(psycopg2.ProgrammingError, + ext.encrypt_password, 'psycopg2', 'ashesh', self.conn, 'abc') def test_encrypt_md5(self): self.assertEqual( @@ -1568,16 +1564,16 @@ class AutocommitTests(ConnectingTestCase): self.assertEqual(self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE) - cur = self.conn.cursor() - cur.execute('select 1;') - self.assertEqual(self.conn.status, ext.STATUS_BEGIN) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_INTRANS) + with self.conn.cursor() as cur: + cur.execute('select 1;') + self.assertEqual(self.conn.status, ext.STATUS_BEGIN) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_INTRANS) - self.conn.rollback() - self.assertEqual(self.conn.status, ext.STATUS_READY) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_IDLE) + self.conn.rollback() + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_IDLE) def test_set_autocommit(self): self.conn.autocommit = True @@ -1586,28 +1582,28 @@ class AutocommitTests(ConnectingTestCase): self.assertEqual(self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE) - cur = self.conn.cursor() - cur.execute('select 1;') - self.assertEqual(self.conn.status, ext.STATUS_READY) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_IDLE) + with self.conn.cursor() as cur: + cur.execute('select 1;') + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_IDLE) - self.conn.autocommit = False - self.assert_(not self.conn.autocommit) - self.assertEqual(self.conn.status, ext.STATUS_READY) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_IDLE) + self.conn.autocommit = False + self.assert_(not self.conn.autocommit) + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_IDLE) - cur.execute('select 1;') - self.assertEqual(self.conn.status, ext.STATUS_BEGIN) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_INTRANS) + cur.execute('select 1;') + self.assertEqual(self.conn.status, ext.STATUS_BEGIN) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_INTRANS) def test_set_intrans_error(self): - cur = self.conn.cursor() - cur.execute('select 1;') - self.assertRaises(psycopg2.ProgrammingError, - setattr, self.conn, 'autocommit', True) + with self.conn.cursor() as cur: + cur.execute('select 1;') + self.assertRaises(psycopg2.ProgrammingError, + setattr, self.conn, 'autocommit', True) def test_set_session_autocommit(self): self.conn.set_session(autocommit=True) @@ -1616,34 +1612,34 @@ class AutocommitTests(ConnectingTestCase): self.assertEqual(self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE) - cur = self.conn.cursor() - cur.execute('select 1;') - self.assertEqual(self.conn.status, ext.STATUS_READY) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_IDLE) + with self.conn.cursor() as cur: + cur.execute('select 1;') + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_IDLE) - self.conn.set_session(autocommit=False) - self.assert_(not self.conn.autocommit) - self.assertEqual(self.conn.status, ext.STATUS_READY) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_IDLE) + self.conn.set_session(autocommit=False) + self.assert_(not self.conn.autocommit) + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_IDLE) - cur.execute('select 1;') - self.assertEqual(self.conn.status, ext.STATUS_BEGIN) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_INTRANS) - self.conn.rollback() + cur.execute('select 1;') + self.assertEqual(self.conn.status, ext.STATUS_BEGIN) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_INTRANS) + self.conn.rollback() - self.conn.set_session('serializable', readonly=True, autocommit=True) - self.assert_(self.conn.autocommit) - cur.execute('select 1;') - self.assertEqual(self.conn.status, ext.STATUS_READY) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_IDLE) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'serializable') - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') + self.conn.set_session('serializable', readonly=True, autocommit=True) + self.assert_(self.conn.autocommit) + cur.execute('select 1;') + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_IDLE) + cur.execute("SHOW transaction_isolation;") + self.assertEqual(cur.fetchone()[0], 'serializable') + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') class PasswordLeakTestCase(ConnectingTestCase): @@ -1762,10 +1758,10 @@ class TestConnectionInfo(ConnectingTestCase): self.assert_(self.bconn.info.dbname is None) def test_user(self): - cur = self.conn.cursor() - cur.execute("select user") - self.assertEqual(self.conn.info.user, cur.fetchone()[0]) - self.assert_(self.bconn.info.user is None) + with self.conn.cursor() as cur: + cur.execute("select user") + self.assertEqual(self.conn.info.user, cur.fetchone()[0]) + self.assert_(self.bconn.info.user is None) def test_password(self): self.assert_(isinstance(self.conn.info.password, str)) @@ -1801,69 +1797,69 @@ class TestConnectionInfo(ConnectingTestCase): def test_transaction_status(self): self.assertEqual(self.conn.info.transaction_status, 0) - cur = self.conn.cursor() - cur.execute("select 1") - self.assertEqual(self.conn.info.transaction_status, 2) - self.assertEqual(self.bconn.info.transaction_status, 4) + with self.conn.cursor() as cur: + cur.execute("select 1") + self.assertEqual(self.conn.info.transaction_status, 2) + self.assertEqual(self.bconn.info.transaction_status, 4) def test_parameter_status(self): - cur = self.conn.cursor() - try: - cur.execute("show server_version") - except psycopg2.DatabaseError: - self.assertIsInstance( - self.conn.info.parameter_status('server_version'), str) - else: - self.assertEqual( - self.conn.info.parameter_status('server_version'), - cur.fetchone()[0]) + with self.conn.cursor() as cur: + try: + cur.execute("show server_version") + except psycopg2.DatabaseError: + self.assertIsInstance( + self.conn.info.parameter_status('server_version'), str) + else: + self.assertEqual( + self.conn.info.parameter_status('server_version'), + cur.fetchone()[0]) - self.assertIsNone(self.conn.info.parameter_status('wat')) - self.assertIsNone(self.bconn.info.parameter_status('server_version')) + self.assertIsNone(self.conn.info.parameter_status('wat')) + self.assertIsNone(self.bconn.info.parameter_status('server_version')) def test_protocol_version(self): self.assertEqual(self.conn.info.protocol_version, 3) self.assertEqual(self.bconn.info.protocol_version, 0) def test_server_version(self): - cur = self.conn.cursor() - try: - cur.execute("show server_version_num") - except psycopg2.DatabaseError: - self.assert_(isinstance(self.conn.info.server_version, int)) - else: - self.assertEqual( - self.conn.info.server_version, int(cur.fetchone()[0])) + with self.conn.cursor() as cur: + try: + cur.execute("show server_version_num") + except psycopg2.DatabaseError: + self.assert_(isinstance(self.conn.info.server_version, int)) + else: + self.assertEqual( + self.conn.info.server_version, int(cur.fetchone()[0])) - self.assertEqual(self.bconn.info.server_version, 0) + self.assertEqual(self.bconn.info.server_version, 0) def test_error_message(self): self.assertIsNone(self.conn.info.error_message) self.assertIsNotNone(self.bconn.info.error_message) - cur = self.conn.cursor() - try: - cur.execute("select 1 from nosuchtable") - except psycopg2.DatabaseError: - pass + with self.conn.cursor() as cur: + try: + cur.execute("select 1 from nosuchtable") + except psycopg2.DatabaseError: + pass - self.assert_('nosuchtable' in self.conn.info.error_message) + self.assert_('nosuchtable' in self.conn.info.error_message) def test_socket(self): self.assert_(self.conn.info.socket >= 0) self.assert_(self.bconn.info.socket < 0) def test_backend_pid(self): - cur = self.conn.cursor() - try: - cur.execute("select pg_backend_pid()") - except psycopg2.DatabaseError: - self.assert_(self.conn.info.backend_pid > 0) - else: - self.assertEqual( - self.conn.info.backend_pid, int(cur.fetchone()[0])) + with self.conn.cursor() as cur: + try: + cur.execute("select pg_backend_pid()") + except psycopg2.DatabaseError: + self.assert_(self.conn.info.backend_pid > 0) + else: + self.assertEqual( + self.conn.info.backend_pid, int(cur.fetchone()[0])) - self.assert_(self.bconn.info.backend_pid == 0) + self.assert_(self.bconn.info.backend_pid == 0) def test_needs_password(self): self.assertIs(self.conn.info.needs_password, False) diff --git a/tests/test_copy.py b/tests/test_copy.py index 4fdf1641..37a2416b 100755 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -66,12 +66,12 @@ class CopyTests(ConnectingTestCase): self._create_temp_table() def _create_temp_table(self): - curs = self.conn.cursor() - curs.execute(''' - CREATE TEMPORARY TABLE tcopy ( - id serial PRIMARY KEY, - data text - )''') + with self.conn.cursor() as curs: + curs.execute(''' + CREATE TEMPORARY TABLE tcopy ( + id serial PRIMARY KEY, + data text + )''') @slow def test_copy_from(self): @@ -92,31 +92,31 @@ class CopyTests(ConnectingTestCase): curs.close() def test_copy_from_cols(self): - curs = self.conn.cursor() - f = StringIO() - for i in range(10): - f.write("%s\n" % (i,)) + with self.conn.cursor() as curs: + f = StringIO() + for i in range(10): + f.write("%s\n" % (i,)) - f.seek(0) - curs.copy_from(MinimalRead(f), "tcopy", columns=['id']) + f.seek(0) + curs.copy_from(MinimalRead(f), "tcopy", columns=['id']) - curs.execute("select * from tcopy order by id") - self.assertEqual([(i, None) for i in range(10)], curs.fetchall()) + curs.execute("select * from tcopy order by id") + self.assertEqual([(i, None) for i in range(10)], curs.fetchall()) def test_copy_from_cols_err(self): - curs = self.conn.cursor() - f = StringIO() - for i in range(10): - f.write("%s\n" % (i,)) + with self.conn.cursor() as curs: + f = StringIO() + for i in range(10): + f.write("%s\n" % (i,)) - f.seek(0) + f.seek(0) - def cols(): - raise ZeroDivisionError() - yield 'id' + def cols(): + raise ZeroDivisionError() + yield 'id' - self.assertRaises(ZeroDivisionError, - curs.copy_from, MinimalRead(f), "tcopy", columns=cols()) + self.assertRaises(ZeroDivisionError, + curs.copy_from, MinimalRead(f), "tcopy", columns=cols()) @slow def test_copy_to(self): @@ -140,14 +140,14 @@ class CopyTests(ConnectingTestCase): + list(range(160, 256))).decode('latin1') about = abin.replace('\\', '\\\\') - curs = self.conn.cursor() - curs.execute('insert into tcopy values (%s, %s)', - (42, abin)) + with self.conn.cursor() as curs: + curs.execute('insert into tcopy values (%s, %s)', + (42, abin)) - f = io.StringIO() - curs.copy_to(f, 'tcopy', columns=('data',)) - f.seek(0) - self.assertEqual(f.readline().rstrip(), about) + f = io.StringIO() + curs.copy_to(f, 'tcopy', columns=('data',)) + f.seek(0) + self.assertEqual(f.readline().rstrip(), about) def test_copy_bytes(self): self.conn.set_client_encoding('latin1') @@ -161,14 +161,14 @@ class CopyTests(ConnectingTestCase): + list(range(160, 255))).decode('latin1') about = abin.replace('\\', '\\\\').encode('latin1') - curs = self.conn.cursor() - curs.execute('insert into tcopy values (%s, %s)', - (42, abin)) + with self.conn.cursor() as curs: + curs.execute('insert into tcopy values (%s, %s)', + (42, abin)) - f = io.BytesIO() - curs.copy_to(f, 'tcopy', columns=('data',)) - f.seek(0) - self.assertEqual(f.readline().rstrip(), about) + f = io.BytesIO() + curs.copy_to(f, 'tcopy', columns=('data',)) + f.seek(0) + self.assertEqual(f.readline().rstrip(), about) def test_copy_expert_textiobase(self): self.conn.set_client_encoding('latin1') @@ -188,35 +188,35 @@ class CopyTests(ConnectingTestCase): f.write(about) f.seek(0) - curs = self.conn.cursor() - psycopg2.extensions.register_type( - psycopg2.extensions.UNICODE, curs) + with self.conn.cursor() as curs: + psycopg2.extensions.register_type( + psycopg2.extensions.UNICODE, curs) - curs.copy_expert('COPY tcopy (data) FROM STDIN', f) - curs.execute("select data from tcopy;") - self.assertEqual(curs.fetchone()[0], abin) + curs.copy_expert('COPY tcopy (data) FROM STDIN', f) + curs.execute("select data from tcopy;") + self.assertEqual(curs.fetchone()[0], abin) - f = io.StringIO() - curs.copy_expert('COPY tcopy (data) TO STDOUT', f) - f.seek(0) - self.assertEqual(f.readline().rstrip(), about) + f = io.StringIO() + curs.copy_expert('COPY tcopy (data) TO STDOUT', f) + f.seek(0) + self.assertEqual(f.readline().rstrip(), about) - # same tests with setting size - f = io.StringIO() - f.write(about) - f.seek(0) - exp_size = 123 - # hack here to leave file as is, only check size when reading - real_read = f.read + # same tests with setting size + f = io.StringIO() + f.write(about) + f.seek(0) + exp_size = 123 + # hack here to leave file as is, only check size when reading + real_read = f.read - def read(_size, f=f, exp_size=exp_size): - self.assertEqual(_size, exp_size) - return real_read(_size) + def read(_size, f=f, exp_size=exp_size): + self.assertEqual(_size, exp_size) + return real_read(_size) - f.read = read - curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size) - curs.execute("select data from tcopy;") - self.assertEqual(curs.fetchone()[0], abin) + f.read = read + curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size) + curs.execute("select data from tcopy;") + self.assertEqual(curs.fetchone()[0], abin) def _copy_from(self, curs, nrecs, srec, copykw): f = StringIO() @@ -254,56 +254,54 @@ class CopyTests(ConnectingTestCase): pass f = Whatever() - curs = self.conn.cursor() - self.assertRaises(TypeError, - curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f) + with self.conn.cursor() as curs: + self.assertRaises(TypeError, + curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f) def test_copy_no_column_limit(self): cols = ["c%050d" % i for i in range(200)] - curs = self.conn.cursor() - curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join( - ["%s int" % c for c in cols])) - curs.execute("INSERT INTO manycols DEFAULT VALUES") + with self.conn.cursor() as curs: + curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join( + ["%s int" % c for c in cols])) + curs.execute("INSERT INTO manycols DEFAULT VALUES") - f = StringIO() - curs.copy_to(f, "manycols", columns=cols) - f.seek(0) - self.assertEqual(f.read().split(), ['\\N'] * len(cols)) + f = StringIO() + curs.copy_to(f, "manycols", columns=cols) + f.seek(0) + self.assertEqual(f.read().split(), ['\\N'] * len(cols)) - f.seek(0) - curs.copy_from(f, "manycols", columns=cols) - curs.execute("select count(*) from manycols;") - self.assertEqual(curs.fetchone()[0], 2) + f.seek(0) + curs.copy_from(f, "manycols", columns=cols) + curs.execute("select count(*) from manycols;") + self.assertEqual(curs.fetchone()[0], 2) @skip_before_postgres(8, 2) # they don't send the count def test_copy_rowcount(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data']) + self.assertEqual(curs.rowcount, 3) - curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data']) - self.assertEqual(curs.rowcount, 3) + curs.copy_expert( + "copy tcopy (data) from stdin", + StringIO('ddd\neee\n')) + self.assertEqual(curs.rowcount, 2) - curs.copy_expert( - "copy tcopy (data) from stdin", - StringIO('ddd\neee\n')) - self.assertEqual(curs.rowcount, 2) + curs.copy_to(StringIO(), "tcopy") + self.assertEqual(curs.rowcount, 5) - curs.copy_to(StringIO(), "tcopy") - self.assertEqual(curs.rowcount, 5) - - curs.execute("insert into tcopy (data) values ('fff')") - curs.copy_expert("copy tcopy to stdout", StringIO()) - self.assertEqual(curs.rowcount, 6) + curs.execute("insert into tcopy (data) values ('fff')") + curs.copy_expert("copy tcopy to stdout", StringIO()) + self.assertEqual(curs.rowcount, 6) def test_copy_rowcount_error(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + curs.execute("insert into tcopy (data) values ('fff')") + self.assertEqual(curs.rowcount, 1) - curs.execute("insert into tcopy (data) values ('fff')") - self.assertEqual(curs.rowcount, 1) - - self.assertRaises(psycopg2.DataError, - curs.copy_from, StringIO('aaa\nbbb\nccc\n'), 'tcopy') - self.assertEqual(curs.rowcount, -1) + self.assertRaises(psycopg2.DataError, + curs.copy_from, StringIO('aaa\nbbb\nccc\n'), 'tcopy') + self.assertEqual(curs.rowcount, -1) @slow def test_copy_from_segfault(self): @@ -317,6 +315,7 @@ try: curs.execute("copy copy_segf from stdin") except psycopg2.ProgrammingError: pass +curs.close() conn.close() """ % {'dsn': dsn}) @@ -336,6 +335,7 @@ try: curs.execute("copy copy_segf to stdout") except psycopg2.ProgrammingError: pass +curs.close() conn.close() """ % {'dsn': dsn}) @@ -351,24 +351,24 @@ conn.close() def readline(self): return 1 / 0 - curs = self.conn.cursor() - # It seems we cannot do this, but now at least we propagate the error - # self.assertRaises(ZeroDivisionError, - # curs.copy_from, BrokenRead(), "tcopy") - try: - curs.copy_from(BrokenRead(), "tcopy") - except Exception as e: - self.assert_('ZeroDivisionError' in str(e)) + with self.conn.cursor() as curs: + # It seems we cannot do this, but now at least we propagate the error + # self.assertRaises(ZeroDivisionError, + # curs.copy_from, BrokenRead(), "tcopy") + try: + curs.copy_from(BrokenRead(), "tcopy") + except Exception as e: + self.assert_('ZeroDivisionError' in str(e)) def test_copy_to_propagate_error(self): class BrokenWrite(TextIOBase): def write(self, data): return 1 / 0 - curs = self.conn.cursor() - curs.execute("insert into tcopy values (10, 'hi')") - self.assertRaises(ZeroDivisionError, - curs.copy_to, BrokenWrite(), "tcopy") + with self.conn.cursor() as curs: + curs.execute("insert into tcopy values (10, 'hi')") + self.assertRaises(ZeroDivisionError, + curs.copy_to, BrokenWrite(), "tcopy") def test_suite(): diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 4d180962..c1e42137 100755 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -51,10 +51,10 @@ class CursorTests(ConnectingTestCase): self.assert_(cur.closed) def test_empty_query(self): - cur = self.conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, cur.execute, "") - self.assertRaises(psycopg2.ProgrammingError, cur.execute, " ") - self.assertRaises(psycopg2.ProgrammingError, cur.execute, ";") + with self.conn.cursor() as cur: + self.assertRaises(psycopg2.ProgrammingError, cur.execute, "") + self.assertRaises(psycopg2.ProgrammingError, cur.execute, " ") + self.assertRaises(psycopg2.ProgrammingError, cur.execute, ";") def test_executemany_propagate_exceptions(self): conn = self.conn @@ -70,58 +70,57 @@ class CursorTests(ConnectingTestCase): def test_mogrify_unicode(self): conn = self.conn - cur = conn.cursor() + with conn.cursor() as cur: + # test consistency between execute and mogrify. - # test consistency between execute and mogrify. + # unicode query containing only ascii data + cur.execute(u"SELECT 'foo';") + self.assertEqual('foo', cur.fetchone()[0]) + self.assertEqual(b"SELECT 'foo';", cur.mogrify(u"SELECT 'foo';")) - # unicode query containing only ascii data - cur.execute(u"SELECT 'foo';") - self.assertEqual('foo', cur.fetchone()[0]) - self.assertEqual(b"SELECT 'foo';", cur.mogrify(u"SELECT 'foo';")) + conn.set_client_encoding('UTF8') + snowman = u"\u2603" - conn.set_client_encoding('UTF8') - snowman = u"\u2603" + def b(s): + if isinstance(s, text_type): + return s.encode('utf8') + else: + return s - def b(s): - if isinstance(s, text_type): - return s.encode('utf8') - else: - return s + # unicode query with non-ascii data + cur.execute(u"SELECT '%s';" % snowman) + self.assertEqual(snowman.encode('utf8'), b(cur.fetchone()[0])) + self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), + cur.mogrify(u"SELECT '%s';" % snowman)) - # unicode query with non-ascii data - cur.execute(u"SELECT '%s';" % snowman) - self.assertEqual(snowman.encode('utf8'), b(cur.fetchone()[0])) - self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), - cur.mogrify(u"SELECT '%s';" % snowman)) + # unicode args + cur.execute("SELECT %s;", (snowman,)) + self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0])) + self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), + cur.mogrify("SELECT %s;", (snowman,))) - # unicode args - cur.execute("SELECT %s;", (snowman,)) - self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0])) - self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), - cur.mogrify("SELECT %s;", (snowman,))) - - # unicode query and args - cur.execute(u"SELECT %s;", (snowman,)) - self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0])) - self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), - cur.mogrify(u"SELECT %s;", (snowman,))) + # unicode query and args + cur.execute(u"SELECT %s;", (snowman,)) + self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0])) + self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), + cur.mogrify(u"SELECT %s;", (snowman,))) def test_mogrify_decimal_explodes(self): conn = self.conn - cur = conn.cursor() - self.assertEqual(b'SELECT 10.3;', - cur.mogrify("SELECT %s;", (Decimal("10.3"),))) + with conn.cursor() as cur: + self.assertEqual(b'SELECT 10.3;', + cur.mogrify("SELECT %s;", (Decimal("10.3"),))) @skip_if_no_getrefcount def test_mogrify_leak_on_multiple_reference(self): # issue #81: reference leak when a parameter value is referenced # more than once from a dict. - cur = self.conn.cursor() - foo = (lambda x: x)('foo') * 10 - nref1 = sys.getrefcount(foo) - cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo}) - nref2 = sys.getrefcount(foo) - self.assertEqual(nref1, nref2) + with self.conn.cursor() as cur: + foo = (lambda x: x)('foo') * 10 + nref1 = sys.getrefcount(foo) + cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo}) + nref2 = sys.getrefcount(foo) + self.assertEqual(nref1, nref2) def test_modify_closed(self): cur = self.conn.cursor() @@ -130,52 +129,51 @@ class CursorTests(ConnectingTestCase): self.assertEqual(sql, b"select 10") def test_bad_placeholder(self): - cur = self.conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, - cur.mogrify, "select %(foo", {}) - self.assertRaises(psycopg2.ProgrammingError, - cur.mogrify, "select %(foo", {'foo': 1}) - self.assertRaises(psycopg2.ProgrammingError, - cur.mogrify, "select %(foo, %(bar)", {'foo': 1}) - self.assertRaises(psycopg2.ProgrammingError, - cur.mogrify, "select %(foo, %(bar)", {'foo': 1, 'bar': 2}) + with self.conn.cursor() as cur: + self.assertRaises(psycopg2.ProgrammingError, + cur.mogrify, "select %(foo", {}) + self.assertRaises(psycopg2.ProgrammingError, + cur.mogrify, "select %(foo", {'foo': 1}) + self.assertRaises(psycopg2.ProgrammingError, + cur.mogrify, "select %(foo, %(bar)", {'foo': 1}) + self.assertRaises(psycopg2.ProgrammingError, + cur.mogrify, "select %(foo, %(bar)", {'foo': 1, 'bar': 2}) def test_cast(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + self.assertEqual(42, curs.cast(20, '42')) + self.assertAlmostEqual(3.14, curs.cast(700, '3.14')) - self.assertEqual(42, curs.cast(20, '42')) - self.assertAlmostEqual(3.14, curs.cast(700, '3.14')) + self.assertEqual(Decimal('123.45'), curs.cast(1700, '123.45')) - self.assertEqual(Decimal('123.45'), curs.cast(1700, '123.45')) - - self.assertEqual(date(2011, 1, 2), curs.cast(1082, '2011-01-02')) - self.assertEqual("who am i?", curs.cast(705, 'who am i?')) # unknown + self.assertEqual(date(2011, 1, 2), curs.cast(1082, '2011-01-02')) + self.assertEqual("who am i?", curs.cast(705, 'who am i?')) # unknown def test_cast_specificity(self): - curs = self.conn.cursor() - self.assertEqual("foo", curs.cast(705, 'foo')) + with self.conn.cursor() as curs: + self.assertEqual("foo", curs.cast(705, 'foo')) - D = psycopg2.extensions.new_type((705,), "DOUBLING", lambda v, c: v * 2) - psycopg2.extensions.register_type(D, self.conn) - self.assertEqual("foofoo", curs.cast(705, 'foo')) + D = psycopg2.extensions.new_type((705,), "DOUBLING", lambda v, c: v * 2) + psycopg2.extensions.register_type(D, self.conn) + self.assertEqual("foofoo", curs.cast(705, 'foo')) - T = psycopg2.extensions.new_type((705,), "TREBLING", lambda v, c: v * 3) - psycopg2.extensions.register_type(T, curs) - self.assertEqual("foofoofoo", curs.cast(705, 'foo')) + T = psycopg2.extensions.new_type((705,), "TREBLING", lambda v, c: v * 3) + psycopg2.extensions.register_type(T, curs) + self.assertEqual("foofoofoo", curs.cast(705, 'foo')) - curs2 = self.conn.cursor() - self.assertEqual("foofoo", curs2.cast(705, 'foo')) + with self.conn.cursor() as curs2: + self.assertEqual("foofoo", curs2.cast(705, 'foo')) def test_weakref(self): - curs = self.conn.cursor() - w = ref(curs) + with self.conn.cursor() as curs: + w = ref(curs) del curs gc.collect() self.assert_(w() is None) def test_null_name(self): - curs = self.conn.cursor(None) - self.assertEqual(curs.name, None) + with self.conn.cursor(None) as curs: + self.assertEqual(curs.name, None) def test_invalid_name(self): curs = self.conn.cursor() @@ -184,9 +182,9 @@ class CursorTests(ConnectingTestCase): curs.execute("insert into invname values (%s)", (i,)) curs.close() - curs = self.conn.cursor(r'1-2-3 \ "test"') - curs.execute("select data from invname order by data") - self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) + with self.conn.cursor(r'1-2-3 \ "test"') as curs: + curs.execute("select data from invname order by data") + self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) def _create_withhold_table(self): curs = self.conn.cursor() @@ -213,15 +211,15 @@ class CursorTests(ConnectingTestCase): self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) curs.close() - curs = self.conn.cursor("W", withhold=True) - self.assertEqual(curs.withhold, True) - curs.execute("select data from withhold order by data") - self.conn.commit() - self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) + with self.conn.cursor("W", withhold=True) as curs: + self.assertEqual(curs.withhold, True) + curs.execute("select data from withhold order by data") + self.conn.commit() + self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) - curs = self.conn.cursor() - curs.execute("drop table withhold") - self.conn.commit() + with self.conn.cursor() as curs: + curs.execute("drop table withhold") + self.conn.commit() def test_withhold_no_begin(self): self._create_withhold_table() @@ -328,134 +326,134 @@ class CursorTests(ConnectingTestCase): return self.skipTest("can't evaluate non-scrollable cursor") curs.close() - curs = self.conn.cursor("S", scrollable=False) - self.assertEqual(curs.scrollable, False) - curs.execute("select * from scrollable") - curs.scroll(2) - self.assertRaises(psycopg2.OperationalError, curs.scroll, -1) + with self.conn.cursor("S", scrollable=False) as curs: + self.assertEqual(curs.scrollable, False) + curs.execute("select * from scrollable") + curs.scroll(2) + self.assertRaises(psycopg2.OperationalError, curs.scroll, -1) @slow @skip_before_postgres(8, 2) def test_iter_named_cursor_efficient(self): - curs = self.conn.cursor('tmp') - # if these records are fetched in the same roundtrip their - # timestamp will not be influenced by the pause in Python world. - curs.execute("""select clock_timestamp() from generate_series(1,2)""") - i = iter(curs) - t1 = next(i)[0] - time.sleep(0.2) - t2 = next(i)[0] - self.assert_((t2 - t1).microseconds * 1e-6 < 0.1, - "named cursor records fetched in 2 roundtrips (delta: %s)" - % (t2 - t1)) + with self.conn.cursor('tmp') as curs: + # if these records are fetched in the same roundtrip their + # timestamp will not be influenced by the pause in Python world. + curs.execute("""select clock_timestamp() from generate_series(1,2)""") + i = iter(curs) + t1 = next(i)[0] + time.sleep(0.2) + t2 = next(i)[0] + self.assert_((t2 - t1).microseconds * 1e-6 < 0.1, + "named cursor records fetched in 2 roundtrips (delta: %s)" + % (t2 - t1)) @skip_before_postgres(8, 0) def test_iter_named_cursor_default_itersize(self): - curs = self.conn.cursor('tmp') - curs.execute('select generate_series(1,50)') - rv = [(r[0], curs.rownumber) for r in curs] - # everything swallowed in one gulp - self.assertEqual(rv, [(i, i) for i in range(1, 51)]) + with self.conn.cursor('tmp') as curs: + curs.execute('select generate_series(1,50)') + rv = [(r[0], curs.rownumber) for r in curs] + # everything swallowed in one gulp + self.assertEqual(rv, [(i, i) for i in range(1, 51)]) @skip_before_postgres(8, 0) def test_iter_named_cursor_itersize(self): - curs = self.conn.cursor('tmp') - curs.itersize = 30 - curs.execute('select generate_series(1,50)') - rv = [(r[0], curs.rownumber) for r in curs] - # everything swallowed in two gulps - self.assertEqual(rv, [(i, ((i - 1) % 30) + 1) for i in range(1, 51)]) + with self.conn.cursor('tmp') as curs: + curs.itersize = 30 + curs.execute('select generate_series(1,50)') + rv = [(r[0], curs.rownumber) for r in curs] + # everything swallowed in two gulps + self.assertEqual(rv, [(i, ((i - 1) % 30) + 1) for i in range(1, 51)]) @skip_before_postgres(8, 0) def test_iter_named_cursor_rownumber(self): - curs = self.conn.cursor('tmp') - # note: this fails if itersize < dataset: internally we check - # rownumber == rowcount to detect when to read anoter page, so we - # would need an extra attribute to have a monotonic rownumber. - curs.itersize = 20 - curs.execute('select generate_series(1,10)') - for i, rec in enumerate(curs): - self.assertEqual(i + 1, curs.rownumber) + with self.conn.cursor('tmp') as curs: + # note: this fails if itersize < dataset: internally we check + # rownumber == rowcount to detect when to read anoter page, so we + # would need an extra attribute to have a monotonic rownumber. + curs.itersize = 20 + curs.execute('select generate_series(1,10)') + for i, rec in enumerate(curs): + self.assertEqual(i + 1, curs.rownumber) def test_description_attribs(self): - curs = self.conn.cursor() - curs.execute("""select - 3.14::decimal(10,2) as pi, - 'hello'::text as hi, - '2010-02-18'::date as now; - """) - self.assertEqual(len(curs.description), 3) - for c in curs.description: - self.assertEqual(len(c), 7) # DBAPI happy - for a in ('name', 'type_code', 'display_size', 'internal_size', - 'precision', 'scale', 'null_ok'): - self.assert_(hasattr(c, a), a) + with self.conn.cursor() as curs: + curs.execute("""select + 3.14::decimal(10,2) as pi, + 'hello'::text as hi, + '2010-02-18'::date as now; + """) + self.assertEqual(len(curs.description), 3) + for c in curs.description: + self.assertEqual(len(c), 7) # DBAPI happy + for a in ('name', 'type_code', 'display_size', 'internal_size', + 'precision', 'scale', 'null_ok'): + self.assert_(hasattr(c, a), a) - c = curs.description[0] - self.assertEqual(c.name, 'pi') - self.assert_(c.type_code in psycopg2.extensions.DECIMAL.values) - self.assert_(c.internal_size > 0) - self.assertEqual(c.precision, 10) - self.assertEqual(c.scale, 2) + c = curs.description[0] + self.assertEqual(c.name, 'pi') + self.assert_(c.type_code in psycopg2.extensions.DECIMAL.values) + self.assert_(c.internal_size > 0) + self.assertEqual(c.precision, 10) + self.assertEqual(c.scale, 2) - c = curs.description[1] - self.assertEqual(c.name, 'hi') - self.assert_(c.type_code in psycopg2.STRING.values) - self.assert_(c.internal_size < 0) - self.assertEqual(c.precision, None) - self.assertEqual(c.scale, None) + c = curs.description[1] + self.assertEqual(c.name, 'hi') + self.assert_(c.type_code in psycopg2.STRING.values) + self.assert_(c.internal_size < 0) + self.assertEqual(c.precision, None) + self.assertEqual(c.scale, None) - c = curs.description[2] - self.assertEqual(c.name, 'now') - self.assert_(c.type_code in psycopg2.extensions.DATE.values) - self.assert_(c.internal_size > 0) - self.assertEqual(c.precision, None) - self.assertEqual(c.scale, None) + c = curs.description[2] + self.assertEqual(c.name, 'now') + self.assert_(c.type_code in psycopg2.extensions.DATE.values) + self.assert_(c.internal_size > 0) + self.assertEqual(c.precision, None) + self.assertEqual(c.scale, None) def test_description_extra_attribs(self): - curs = self.conn.cursor() - curs.execute(""" - create table testcol ( - pi decimal(10,2), - hi text) - """) - curs.execute("select oid from pg_class where relname = %s", ('testcol',)) - oid = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute(""" + create table testcol ( + pi decimal(10,2), + hi text) + """) + curs.execute("select oid from pg_class where relname = %s", ('testcol',)) + oid = curs.fetchone()[0] - curs.execute("insert into testcol values (3.14, 'hello')") - curs.execute("select hi, pi, 42 from testcol") - self.assertEqual(curs.description[0].table_oid, oid) - self.assertEqual(curs.description[0].table_column, 2) + curs.execute("insert into testcol values (3.14, 'hello')") + curs.execute("select hi, pi, 42 from testcol") + self.assertEqual(curs.description[0].table_oid, oid) + self.assertEqual(curs.description[0].table_column, 2) - self.assertEqual(curs.description[1].table_oid, oid) - self.assertEqual(curs.description[1].table_column, 1) + self.assertEqual(curs.description[1].table_oid, oid) + self.assertEqual(curs.description[1].table_column, 1) - self.assertEqual(curs.description[2].table_oid, None) - self.assertEqual(curs.description[2].table_column, None) + self.assertEqual(curs.description[2].table_oid, None) + self.assertEqual(curs.description[2].table_column, None) def test_pickle_description(self): - curs = self.conn.cursor() - curs.execute('SELECT 1 AS foo') - description = curs.description + with self.conn.cursor() as curs: + curs.execute('SELECT 1 AS foo') + description = curs.description - pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL) - unpickled = pickle.loads(pickled) + pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL) + unpickled = pickle.loads(pickled) - self.assertEqual(description, unpickled) + self.assertEqual(description, unpickled) @skip_before_postgres(8, 0) def test_named_cursor_stealing(self): # you can use a named cursor to iterate on a refcursor created # somewhere else - cur1 = self.conn.cursor() - cur1.execute("DECLARE test CURSOR WITHOUT HOLD " - " FOR SELECT generate_series(1,7)") + with self.conn.cursor() as cur1: + cur1.execute("DECLARE test CURSOR WITHOUT HOLD " + " FOR SELECT generate_series(1,7)") - cur2 = self.conn.cursor('test') - # can call fetch without execute - self.assertEqual((1,), cur2.fetchone()) - self.assertEqual([(2,), (3,), (4,)], cur2.fetchmany(3)) - self.assertEqual([(5,), (6,), (7,)], cur2.fetchall()) + with self.conn.cursor('test') as cur2: + # can call fetch without execute + self.assertEqual((1,), cur2.fetchone()) + self.assertEqual([(2,), (3,), (4,)], cur2.fetchmany(3)) + self.assertEqual([(5,), (6,), (7,)], cur2.fetchall()) @skip_before_postgres(8, 2) def test_named_noop_close(self): @@ -464,63 +462,63 @@ class CursorTests(ConnectingTestCase): @skip_before_postgres(8, 2) def test_stolen_named_cursor_close(self): - cur1 = self.conn.cursor() - cur1.execute("DECLARE test CURSOR WITHOUT HOLD " - " FOR SELECT generate_series(1,7)") - cur2 = self.conn.cursor('test') - cur2.close() + with self.conn.cursor() as cur1: + cur1.execute("DECLARE test CURSOR WITHOUT HOLD " + " FOR SELECT generate_series(1,7)") + cur2 = self.conn.cursor('test') + cur2.close() - cur1.execute("DECLARE test CURSOR WITHOUT HOLD " - " FOR SELECT generate_series(1,7)") - cur2 = self.conn.cursor('test') - cur2.close() + cur1.execute("DECLARE test CURSOR WITHOUT HOLD " + " FOR SELECT generate_series(1,7)") + cur2 = self.conn.cursor('test') + cur2.close() @skip_before_postgres(8, 0) def test_scroll(self): - cur = self.conn.cursor() - cur.execute("select generate_series(0,9)") - cur.scroll(2) - self.assertEqual(cur.fetchone(), (2,)) - cur.scroll(2) - self.assertEqual(cur.fetchone(), (5,)) - cur.scroll(2, mode='relative') - self.assertEqual(cur.fetchone(), (8,)) - cur.scroll(-1) - self.assertEqual(cur.fetchone(), (8,)) - cur.scroll(-2) - self.assertEqual(cur.fetchone(), (7,)) - cur.scroll(2, mode='absolute') - self.assertEqual(cur.fetchone(), (2,)) + with self.conn.cursor() as cur: + cur.execute("select generate_series(0,9)") + cur.scroll(2) + self.assertEqual(cur.fetchone(), (2,)) + cur.scroll(2) + self.assertEqual(cur.fetchone(), (5,)) + cur.scroll(2, mode='relative') + self.assertEqual(cur.fetchone(), (8,)) + cur.scroll(-1) + self.assertEqual(cur.fetchone(), (8,)) + cur.scroll(-2) + self.assertEqual(cur.fetchone(), (7,)) + cur.scroll(2, mode='absolute') + self.assertEqual(cur.fetchone(), (2,)) - # on the boundary - cur.scroll(0, mode='absolute') - self.assertEqual(cur.fetchone(), (0,)) - self.assertRaises((IndexError, psycopg2.ProgrammingError), - cur.scroll, -1, mode='absolute') - cur.scroll(0, mode='absolute') - self.assertRaises((IndexError, psycopg2.ProgrammingError), - cur.scroll, -1) + # on the boundary + cur.scroll(0, mode='absolute') + self.assertEqual(cur.fetchone(), (0,)) + self.assertRaises((IndexError, psycopg2.ProgrammingError), + cur.scroll, -1, mode='absolute') + cur.scroll(0, mode='absolute') + self.assertRaises((IndexError, psycopg2.ProgrammingError), + cur.scroll, -1) - cur.scroll(9, mode='absolute') - self.assertEqual(cur.fetchone(), (9,)) - self.assertRaises((IndexError, psycopg2.ProgrammingError), - cur.scroll, 10, mode='absolute') - cur.scroll(9, mode='absolute') - self.assertRaises((IndexError, psycopg2.ProgrammingError), - cur.scroll, 1) + cur.scroll(9, mode='absolute') + self.assertEqual(cur.fetchone(), (9,)) + self.assertRaises((IndexError, psycopg2.ProgrammingError), + cur.scroll, 10, mode='absolute') + cur.scroll(9, mode='absolute') + self.assertRaises((IndexError, psycopg2.ProgrammingError), + cur.scroll, 1) @skip_before_postgres(8, 0) def test_scroll_named(self): - cur = self.conn.cursor('tmp', scrollable=True) - cur.execute("select generate_series(0,9)") - cur.scroll(2) - self.assertEqual(cur.fetchone(), (2,)) - cur.scroll(2) - self.assertEqual(cur.fetchone(), (5,)) - cur.scroll(2, mode='relative') - self.assertEqual(cur.fetchone(), (8,)) - cur.scroll(9, mode='absolute') - self.assertEqual(cur.fetchone(), (9,)) + with self.conn.cursor('tmp', scrollable=True) as cur: + cur.execute("select generate_series(0,9)") + cur.scroll(2) + self.assertEqual(cur.fetchone(), (2,)) + cur.scroll(2) + self.assertEqual(cur.fetchone(), (5,)) + cur.scroll(2, mode='relative') + self.assertEqual(cur.fetchone(), (8,)) + cur.scroll(9, mode='absolute') + self.assertEqual(cur.fetchone(), (9,)) def test_bad_subclass(self): # check that we get an error message instead of a segfault @@ -531,14 +529,14 @@ class CursorTests(ConnectingTestCase): # I am stupid so not calling superclass init pass - cur = StupidCursor() - self.assertRaises(psycopg2.InterfaceError, cur.execute, 'select 1') - self.assertRaises(psycopg2.InterfaceError, cur.executemany, - 'select 1', []) + with StupidCursor() as cur: + self.assertRaises(psycopg2.InterfaceError, cur.execute, 'select 1') + self.assertRaises(psycopg2.InterfaceError, cur.executemany, + 'select 1', []) def test_callproc_badparam(self): - cur = self.conn.cursor() - self.assertRaises(TypeError, cur.callproc, 'lower', 42) + with self.conn.cursor() as cur: + self.assertRaises(TypeError, cur.callproc, 'lower', 42) # It would be inappropriate to test callproc's named parameters in the # DBAPI2.0 test section because they are a psycopg2 extension. @@ -551,32 +549,31 @@ class CursorTests(ConnectingTestCase): escaped_paramname = '"%s"' % paramname.replace('"', '""') procname = 'pg_temp.randall' - cur = self.conn.cursor() + with self.conn.cursor() as cur: + # Set up the temporary function + cur.execute(''' + CREATE FUNCTION %s(%s INT) + RETURNS INT AS + 'SELECT $1 * $1' + LANGUAGE SQL + ''' % (procname, escaped_paramname)) - # Set up the temporary function - cur.execute(''' - CREATE FUNCTION %s(%s INT) - RETURNS INT AS - 'SELECT $1 * $1' - LANGUAGE SQL - ''' % (procname, escaped_paramname)) + # Make sure callproc works right + cur.callproc(procname, {paramname: 2}) + self.assertEquals(cur.fetchone()[0], 4) - # Make sure callproc works right - cur.callproc(procname, {paramname: 2}) - self.assertEquals(cur.fetchone()[0], 4) - - # Make sure callproc fails right - failing_cases = [ - ({paramname: 2, 'foo': 'bar'}, psycopg2.ProgrammingError), - ({paramname: '2'}, psycopg2.ProgrammingError), - ({paramname: 'two'}, psycopg2.ProgrammingError), - ({u'bj\xc3rn': 2}, psycopg2.ProgrammingError), - ({3: 2}, TypeError), - ({self: 2}, TypeError), - ] - for parameter_sequence, exception in failing_cases: - self.assertRaises(exception, cur.callproc, procname, parameter_sequence) - self.conn.rollback() + # Make sure callproc fails right + failing_cases = [ + ({paramname: 2, 'foo': 'bar'}, psycopg2.ProgrammingError), + ({paramname: '2'}, psycopg2.ProgrammingError), + ({paramname: 'two'}, psycopg2.ProgrammingError), + ({u'bj\xc3rn': 2}, psycopg2.ProgrammingError), + ({3: 2}, TypeError), + ({self: 2}, TypeError), + ] + for parameter_sequence, exception in failing_cases: + self.assertRaises(exception, cur.callproc, procname, parameter_sequence) + self.conn.rollback() @skip_if_no_superuser @skip_if_windows @@ -638,17 +635,17 @@ class CursorTests(ConnectingTestCase): @skip_before_postgres(8, 2) def test_rowcount_on_executemany_returning(self): - cur = self.conn.cursor() - cur.execute("create table execmany(id serial primary key, data int)") - cur.executemany( - "insert into execmany (data) values (%s)", - [(i,) for i in range(4)]) - self.assertEqual(cur.rowcount, 4) + with self.conn.cursor() as cur: + cur.execute("create table execmany(id serial primary key, data int)") + cur.executemany( + "insert into execmany (data) values (%s)", + [(i,) for i in range(4)]) + self.assertEqual(cur.rowcount, 4) - cur.executemany( - "insert into execmany (data) values (%s) returning data", - [(i,) for i in range(5)]) - self.assertEqual(cur.rowcount, 5) + cur.executemany( + "insert into execmany (data) values (%s) returning data", + [(i,) for i in range(5)]) + self.assertEqual(cur.rowcount, 5) @skip_before_postgres(9) def test_pgresult_ptr(self): diff --git a/tests/test_dates.py b/tests/test_dates.py index b9aac695..829461b9 100755 --- a/tests/test_dates.py +++ b/tests/test_dates.py @@ -369,22 +369,22 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin): self.assertEqual(total_seconds(t), 1e-6) def test_interval_overflow(self): - cur = self.conn.cursor() - # hack a cursor to receive values too extreme to be represented - # but still I want an error, not a random number - psycopg2.extensions.register_type( - psycopg2.extensions.new_type( - psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL), - cur) + with self.conn.cursor() as cur: + # hack a cursor to receive values too extreme to be represented + # but still I want an error, not a random number + psycopg2.extensions.register_type( + psycopg2.extensions.new_type( + psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL), + cur) - def f(val): - cur.execute("select '%s'::text" % val) - return cur.fetchone()[0] + def f(val): + cur.execute("select '%s'::text" % val) + return cur.fetchone()[0] - self.assertRaises(OverflowError, f, '100000000000000000:00:00') - self.assertRaises(OverflowError, f, '00:100000000000000000:00:00') - self.assertRaises(OverflowError, f, '00:00:100000000000000000:00') - self.assertRaises(OverflowError, f, '00:00:00.100000000000000000') + self.assertRaises(OverflowError, f, '100000000000000000:00:00') + self.assertRaises(OverflowError, f, '00:100000000000000000:00:00') + self.assertRaises(OverflowError, f, '00:00:100000000000000000:00') + self.assertRaises(OverflowError, f, '00:00:00.100000000000000000') def test_adapt_infinity_tz(self): t = self.execute("select 'infinity'::timestamp") @@ -405,31 +405,31 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin): def test_redshift_day(self): # Redshift is reported returning 1 day interval as microsec (bug #558) - cur = self.conn.cursor() - psycopg2.extensions.register_type( - psycopg2.extensions.new_type( - psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL), - cur) + with self.conn.cursor() as cur: + psycopg2.extensions.register_type( + psycopg2.extensions.new_type( + psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL), + cur) - for s, v in [ - ('0', timedelta(0)), - ('1', timedelta(microseconds=1)), - ('-1', timedelta(microseconds=-1)), - ('1000000', timedelta(seconds=1)), - ('86400000000', timedelta(days=1)), - ('-86400000000', timedelta(days=-1)), - ]: - cur.execute("select %s::text", (s,)) - r = cur.fetchone()[0] - self.assertEqual(r, v, "%s -> %s != %s" % (s, r, v)) + for s, v in [ + ('0', timedelta(0)), + ('1', timedelta(microseconds=1)), + ('-1', timedelta(microseconds=-1)), + ('1000000', timedelta(seconds=1)), + ('86400000000', timedelta(days=1)), + ('-86400000000', timedelta(days=-1)), + ]: + cur.execute("select %s::text", (s,)) + r = cur.fetchone()[0] + self.assertEqual(r, v, "%s -> %s != %s" % (s, r, v)) @skip_before_postgres(8, 4) def test_interval_iso_8601_not_supported(self): # We may end up supporting, but no pressure for it - cur = self.conn.cursor() - cur.execute("set local intervalstyle to iso_8601") - cur.execute("select '1 day 2 hours'::interval") - self.assertRaises(psycopg2.NotSupportedError, cur.fetchone) + with self.conn.cursor() as cur: + cur.execute("set local intervalstyle to iso_8601") + cur.execute("select '1 day 2 hours'::interval") + self.assertRaises(psycopg2.NotSupportedError, cur.fetchone) @unittest.skipUnless( diff --git a/tests/test_extras_dictcursor.py b/tests/test_extras_dictcursor.py index d4bb12f5..80fd4cca 100755 --- a/tests/test_extras_dictcursor.py +++ b/tests/test_extras_dictcursor.py @@ -32,9 +32,9 @@ from .testutils import ConnectingTestCase, skip_before_postgres, \ class _DictCursorBase(ConnectingTestCase): def setUp(self): ConnectingTestCase.setUp(self) - curs = self.conn.cursor() - curs.execute("CREATE TEMPORARY TABLE ExtrasDictCursorTests (foo text)") - curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')") + with self.conn.cursor() as curs: + curs.execute("CREATE TEMPORARY TABLE ExtrasDictCursorTests (foo text)") + curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')") self.conn.commit() def _testIterRowNumber(self, curs): @@ -61,17 +61,20 @@ class _DictCursorBase(ConnectingTestCase): class ExtrasDictCursorTests(_DictCursorBase): """Test if DictCursor extension class works.""" + @skip_before_postgres(8, 2) def testDictConnCursorArgs(self): self.conn.close() self.conn = self.connect(connection_factory=psycopg2.extras.DictConnection) - cur = self.conn.cursor() - self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) - self.assertEqual(cur.name, None) - # overridable - cur = self.conn.cursor('foo', - cursor_factory=psycopg2.extras.NamedTupleCursor) - self.assertEqual(cur.name, 'foo') - self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor)) + with self.conn.cursor() as cur: + self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) + self.assertEqual(cur.name, None) + # overridable + with self.conn.cursor( + 'foo', + cursor_factory=psycopg2.extras.NamedTupleCursor + ) as cur: + self.assertEqual(cur.name, 'foo') + self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor)) def testDictCursorWithPlainCursorFetchOne(self): self._testWithPlainCursor(lambda curs: curs.fetchone()) @@ -99,13 +102,13 @@ class ExtrasDictCursorTests(_DictCursorBase): @skip_before_postgres(8, 0) def testDictCursorWithPlainCursorIterRowNumber(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) - self._testIterRowNumber(curs) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs: + self._testIterRowNumber(curs) def _testWithPlainCursor(self, getter): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) - curs.execute("SELECT * FROM ExtrasDictCursorTests") - row = getter(curs) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs: + curs.execute("SELECT * FROM ExtrasDictCursorTests") + row = getter(curs) self.failUnless(row['foo'] == 'bar') self.failUnless(row[0] == 'bar') return row @@ -130,80 +133,89 @@ class ExtrasDictCursorTests(_DictCursorBase): @skip_before_postgres(8, 2) def testDictCursorWithNamedCursorNotGreedy(self): - curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.DictCursor) - self._testNamedCursorNotGreedy(curs) + with self.conn.cursor( + 'tmp', + cursor_factory=psycopg2.extras.DictCursor + ) as curs: + self._testNamedCursorNotGreedy(curs) @skip_before_postgres(8, 0) def testDictCursorWithNamedCursorIterRowNumber(self): - curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.DictCursor) - self._testIterRowNumber(curs) + with self.conn.cursor( + 'tmp', + cursor_factory=psycopg2.extras.DictCursor + ) as curs: + self._testIterRowNumber(curs) def _testWithNamedCursor(self, getter): - curs = self.conn.cursor('aname', cursor_factory=psycopg2.extras.DictCursor) - curs.execute("SELECT * FROM ExtrasDictCursorTests") - row = getter(curs) - self.failUnless(row['foo'] == 'bar') - self.failUnless(row[0] == 'bar') + with self.conn.cursor( + 'aname', + cursor_factory=psycopg2.extras.DictCursor + ) as curs: + curs.execute("SELECT * FROM ExtrasDictCursorTests") + row = getter(curs) + self.failUnless(row['foo'] == 'bar') + self.failUnless(row[0] == 'bar') def testPickleDictRow(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) - curs.execute("select 10 as a, 20 as b") - r = curs.fetchone() - d = pickle.dumps(r) - r1 = pickle.loads(d) - self.assertEqual(r, r1) - self.assertEqual(r[0], r1[0]) - self.assertEqual(r[1], r1[1]) - self.assertEqual(r['a'], r1['a']) - self.assertEqual(r['b'], r1['b']) - self.assertEqual(r._index, r1._index) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs: + curs.execute("select 10 as a, 20 as b") + r = curs.fetchone() + d = pickle.dumps(r) + r1 = pickle.loads(d) + self.assertEqual(r, r1) + self.assertEqual(r[0], r1[0]) + self.assertEqual(r[1], r1[1]) + self.assertEqual(r['a'], r1['a']) + self.assertEqual(r['b'], r1['b']) + self.assertEqual(r._index, r1._index) @skip_from_python(3) def test_iter_methods_2(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) - curs.execute("select 10 as a, 20 as b") - r = curs.fetchone() - self.assert_(isinstance(r.keys(), list)) - self.assertEqual(len(r.keys()), 2) - self.assert_(isinstance(r.values(), tuple)) # sic? - self.assertEqual(len(r.values()), 2) - self.assert_(isinstance(r.items(), list)) - self.assertEqual(len(r.items()), 2) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs: + curs.execute("select 10 as a, 20 as b") + r = curs.fetchone() + self.assert_(isinstance(r.keys(), list)) + self.assertEqual(len(r.keys()), 2) + self.assert_(isinstance(r.values(), tuple)) # sic? + self.assertEqual(len(r.values()), 2) + self.assert_(isinstance(r.items(), list)) + self.assertEqual(len(r.items()), 2) - self.assert_(not isinstance(r.iterkeys(), list)) - self.assertEqual(len(list(r.iterkeys())), 2) - self.assert_(not isinstance(r.itervalues(), list)) - self.assertEqual(len(list(r.itervalues())), 2) - self.assert_(not isinstance(r.iteritems(), list)) - self.assertEqual(len(list(r.iteritems())), 2) + self.assert_(not isinstance(r.iterkeys(), list)) + self.assertEqual(len(list(r.iterkeys())), 2) + self.assert_(not isinstance(r.itervalues(), list)) + self.assertEqual(len(list(r.itervalues())), 2) + self.assert_(not isinstance(r.iteritems(), list)) + self.assertEqual(len(list(r.iteritems())), 2) @skip_before_python(3) def test_iter_methods_3(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) - curs.execute("select 10 as a, 20 as b") - r = curs.fetchone() - self.assert_(not isinstance(r.keys(), list)) - self.assertEqual(len(list(r.keys())), 2) - self.assert_(not isinstance(r.values(), list)) - self.assertEqual(len(list(r.values())), 2) - self.assert_(not isinstance(r.items(), list)) - self.assertEqual(len(list(r.items())), 2) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs: + curs.execute("select 10 as a, 20 as b") + r = curs.fetchone() + self.assert_(not isinstance(r.keys(), list)) + self.assertEqual(len(list(r.keys())), 2) + self.assert_(not isinstance(r.values(), list)) + self.assertEqual(len(list(r.values())), 2) + self.assert_(not isinstance(r.items(), list)) + self.assertEqual(len(list(r.items())), 2) def test_order(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) - curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") - r = curs.fetchone() - self.assertEqual(list(r), [5, 4, 33, 2]) - self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux']) - self.assertEqual(list(r.values()), [5, 4, 33, 2]) - self.assertEqual(list(r.items()), - [('foo', 5), ('bar', 4), ('baz', 33), ('qux', 2)]) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs: + curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") + r = curs.fetchone() + self.assertEqual(list(r), [5, 4, 33, 2]) + self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux']) + self.assertEqual(list(r.values()), [5, 4, 33, 2]) + self.assertEqual(list(r.items()), + [('foo', 5), ('bar', 4), ('baz', 33), ('qux', 2)]) - r1 = pickle.loads(pickle.dumps(r)) - self.assertEqual(list(r1), list(r)) - self.assertEqual(list(r1.keys()), list(r.keys())) - self.assertEqual(list(r1.values()), list(r.values())) - self.assertEqual(list(r1.items()), list(r.items())) + r1 = pickle.loads(pickle.dumps(r)) + self.assertEqual(list(r1), list(r)) + self.assertEqual(list(r1.keys()), list(r.keys())) + self.assertEqual(list(r1.values()), list(r.values())) + self.assertEqual(list(r1.items()), list(r.items())) @skip_from_python(3) def test_order_iter(self): @@ -223,10 +235,10 @@ class ExtrasDictCursorTests(_DictCursorBase): class ExtrasDictCursorRealTests(_DictCursorBase): def testRealMeansReal(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("SELECT * FROM ExtrasDictCursorTests") - row = curs.fetchone() - self.assert_(isinstance(row, dict)) + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("SELECT * FROM ExtrasDictCursorTests") + row = curs.fetchone() + self.assert_(isinstance(row, dict)) def testDictCursorWithPlainCursorRealFetchOne(self): self._testWithPlainCursorReal(lambda curs: curs.fetchone()) @@ -248,24 +260,24 @@ class ExtrasDictCursorRealTests(_DictCursorBase): @skip_before_postgres(8, 0) def testDictCursorWithPlainCursorRealIterRowNumber(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - self._testIterRowNumber(curs) + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + self._testIterRowNumber(curs) def _testWithPlainCursorReal(self, getter): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("SELECT * FROM ExtrasDictCursorTests") - row = getter(curs) - self.failUnless(row['foo'] == 'bar') + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("SELECT * FROM ExtrasDictCursorTests") + row = getter(curs) + self.failUnless(row['foo'] == 'bar') def testPickleRealDictRow(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 10 as a, 20 as b") - r = curs.fetchone() - d = pickle.dumps(r) - r1 = pickle.loads(d) - self.assertEqual(r, r1) - self.assertEqual(r['a'], r1['a']) - self.assertEqual(r['b'], r1['b']) + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 10 as a, 20 as b") + r = curs.fetchone() + d = pickle.dumps(r) + r1 = pickle.loads(d) + self.assertEqual(r, r1) + self.assertEqual(r['a'], r1['a']) + self.assertEqual(r['b'], r1['b']) def testDictCursorRealWithNamedCursorFetchOne(self): self._testWithNamedCursorReal(lambda curs: curs.fetchone()) @@ -287,26 +299,34 @@ class ExtrasDictCursorRealTests(_DictCursorBase): @skip_before_postgres(8, 2) def testDictCursorRealWithNamedCursorNotGreedy(self): - curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.RealDictCursor) - self._testNamedCursorNotGreedy(curs) + with self.conn.cursor( + 'tmp', + cursor_factory=psycopg2.extras.RealDictCursor + ) as curs: + self._testNamedCursorNotGreedy(curs) @skip_before_postgres(8, 0) def testDictCursorRealWithNamedCursorIterRowNumber(self): - curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.RealDictCursor) - self._testIterRowNumber(curs) + with self.conn.cursor( + 'tmp', + cursor_factory=psycopg2.extras.RealDictCursor + ) as curs: + self._testIterRowNumber(curs) def _testWithNamedCursorReal(self, getter): - curs = self.conn.cursor('aname', - cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("SELECT * FROM ExtrasDictCursorTests") - row = getter(curs) - self.failUnless(row['foo'] == 'bar') + with self.conn.cursor( + 'aname', + cursor_factory=psycopg2.extras.RealDictCursor + ) as curs: + curs.execute("SELECT * FROM ExtrasDictCursorTests") + row = getter(curs) + self.failUnless(row['foo'] == 'bar') @skip_from_python(3) def test_iter_methods_2(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 10 as a, 20 as b") - r = curs.fetchone() + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 10 as a, 20 as b") + r = curs.fetchone() self.assert_(isinstance(r.keys(), list)) self.assertEqual(len(r.keys()), 2) self.assert_(isinstance(r.values(), list)) @@ -323,9 +343,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase): @skip_before_python(3) def test_iter_methods_3(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 10 as a, 20 as b") - r = curs.fetchone() + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 10 as a, 20 as b") + r = curs.fetchone() self.assert_(not isinstance(r.keys(), list)) self.assertEqual(len(list(r.keys())), 2) self.assert_(not isinstance(r.values(), list)) @@ -334,9 +354,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase): self.assertEqual(len(list(r.items())), 2) def test_order(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") - r = curs.fetchone() + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") + r = curs.fetchone() self.assertEqual(list(r), ['foo', 'bar', 'baz', 'qux']) self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux']) self.assertEqual(list(r.values()), [5, 4, 33, 2]) @@ -351,9 +371,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase): @skip_from_python(3) def test_order_iter(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") - r = curs.fetchone() + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") + r = curs.fetchone() self.assertEqual(list(r.iterkeys()), ['foo', 'bar', 'baz', 'qux']) self.assertEqual(list(r.itervalues()), [5, 4, 33, 2]) self.assertEqual(list(r.iteritems()), @@ -365,9 +385,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase): self.assertEqual(list(r1.iteritems()), list(r.iteritems())) def test_pop(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 1 as a, 2 as b, 3 as c") - r = curs.fetchone() + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 1 as a, 2 as b, 3 as c") + r = curs.fetchone() self.assertEqual(r.pop('b'), 2) self.assertEqual(list(r), ['a', 'c']) self.assertEqual(list(r.keys()), ['a', 'c']) @@ -378,9 +398,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase): self.assertRaises(KeyError, r.pop, 'b') def test_mod(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 1 as a, 2 as b, 3 as c") - r = curs.fetchone() + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 1 as a, 2 as b, 3 as c") + r = curs.fetchone() r['d'] = 4 self.assertEqual(list(r), ['a', 'b', 'c', 'd']) self.assertEqual(list(r.keys()), ['a', 'b', 'c', 'd']) @@ -399,137 +419,141 @@ class NamedTupleCursorTest(ConnectingTestCase): ConnectingTestCase.setUp(self) self.conn = self.connect(connection_factory=NamedTupleConnection) - curs = self.conn.cursor() - curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)") - curs.execute("INSERT INTO nttest VALUES (1, 'foo')") - curs.execute("INSERT INTO nttest VALUES (2, 'bar')") - curs.execute("INSERT INTO nttest VALUES (3, 'baz')") + with self.conn.cursor() as curs: + curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)") + curs.execute("INSERT INTO nttest VALUES (1, 'foo')") + curs.execute("INSERT INTO nttest VALUES (2, 'bar')") + curs.execute("INSERT INTO nttest VALUES (3, 'baz')") self.conn.commit() + @skip_before_postgres(8, 2) def test_cursor_args(self): - cur = self.conn.cursor('foo', cursor_factory=psycopg2.extras.DictCursor) - self.assertEqual(cur.name, 'foo') - self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) + with self.conn.cursor( + 'foo', + cursor_factory=psycopg2.extras.DictCursor + ) as cur: + self.assertEqual(cur.name, 'foo') + self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) def test_fetchone(self): - curs = self.conn.cursor() - curs.execute("select * from nttest order by 1") - t = curs.fetchone() - self.assertEqual(t[0], 1) - self.assertEqual(t.i, 1) - self.assertEqual(t[1], 'foo') - self.assertEqual(t.s, 'foo') - self.assertEqual(curs.rownumber, 1) - self.assertEqual(curs.rowcount, 3) + with self.conn.cursor() as curs: + curs.execute("select * from nttest order by 1") + t = curs.fetchone() + self.assertEqual(t[0], 1) + self.assertEqual(t.i, 1) + self.assertEqual(t[1], 'foo') + self.assertEqual(t.s, 'foo') + self.assertEqual(curs.rownumber, 1) + self.assertEqual(curs.rowcount, 3) def test_fetchmany_noarg(self): - curs = self.conn.cursor() - curs.arraysize = 2 - curs.execute("select * from nttest order by 1") - res = curs.fetchmany() - self.assertEqual(2, len(res)) - self.assertEqual(res[0].i, 1) - self.assertEqual(res[0].s, 'foo') - self.assertEqual(res[1].i, 2) - self.assertEqual(res[1].s, 'bar') - self.assertEqual(curs.rownumber, 2) - self.assertEqual(curs.rowcount, 3) + with self.conn.cursor() as curs: + curs.arraysize = 2 + curs.execute("select * from nttest order by 1") + res = curs.fetchmany() + self.assertEqual(2, len(res)) + self.assertEqual(res[0].i, 1) + self.assertEqual(res[0].s, 'foo') + self.assertEqual(res[1].i, 2) + self.assertEqual(res[1].s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) def test_fetchmany(self): - curs = self.conn.cursor() - curs.execute("select * from nttest order by 1") - res = curs.fetchmany(2) - self.assertEqual(2, len(res)) - self.assertEqual(res[0].i, 1) - self.assertEqual(res[0].s, 'foo') - self.assertEqual(res[1].i, 2) - self.assertEqual(res[1].s, 'bar') - self.assertEqual(curs.rownumber, 2) - self.assertEqual(curs.rowcount, 3) + with self.conn.cursor() as curs: + curs.execute("select * from nttest order by 1") + res = curs.fetchmany(2) + self.assertEqual(2, len(res)) + self.assertEqual(res[0].i, 1) + self.assertEqual(res[0].s, 'foo') + self.assertEqual(res[1].i, 2) + self.assertEqual(res[1].s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) def test_fetchall(self): - curs = self.conn.cursor() - curs.execute("select * from nttest order by 1") - res = curs.fetchall() - self.assertEqual(3, len(res)) - self.assertEqual(res[0].i, 1) - self.assertEqual(res[0].s, 'foo') - self.assertEqual(res[1].i, 2) - self.assertEqual(res[1].s, 'bar') - self.assertEqual(res[2].i, 3) - self.assertEqual(res[2].s, 'baz') - self.assertEqual(curs.rownumber, 3) - self.assertEqual(curs.rowcount, 3) + with self.conn.cursor() as curs: + curs.execute("select * from nttest order by 1") + res = curs.fetchall() + self.assertEqual(3, len(res)) + self.assertEqual(res[0].i, 1) + self.assertEqual(res[0].s, 'foo') + self.assertEqual(res[1].i, 2) + self.assertEqual(res[1].s, 'bar') + self.assertEqual(res[2].i, 3) + self.assertEqual(res[2].s, 'baz') + self.assertEqual(curs.rownumber, 3) + self.assertEqual(curs.rowcount, 3) def test_executemany(self): - curs = self.conn.cursor() - curs.executemany("delete from nttest where i = %s", - [(1,), (2,)]) - curs.execute("select * from nttest order by 1") - res = curs.fetchall() + with self.conn.cursor() as curs: + curs.executemany("delete from nttest where i = %s", + [(1,), (2,)]) + curs.execute("select * from nttest order by 1") + res = curs.fetchall() self.assertEqual(1, len(res)) self.assertEqual(res[0].i, 3) self.assertEqual(res[0].s, 'baz') def test_iter(self): - curs = self.conn.cursor() - curs.execute("select * from nttest order by 1") - i = iter(curs) - self.assertEqual(curs.rownumber, 0) + with self.conn.cursor() as curs: + curs.execute("select * from nttest order by 1") + i = iter(curs) + self.assertEqual(curs.rownumber, 0) - t = next(i) - self.assertEqual(t.i, 1) - self.assertEqual(t.s, 'foo') - self.assertEqual(curs.rownumber, 1) - self.assertEqual(curs.rowcount, 3) + t = next(i) + self.assertEqual(t.i, 1) + self.assertEqual(t.s, 'foo') + self.assertEqual(curs.rownumber, 1) + self.assertEqual(curs.rowcount, 3) - t = next(i) - self.assertEqual(t.i, 2) - self.assertEqual(t.s, 'bar') - self.assertEqual(curs.rownumber, 2) - self.assertEqual(curs.rowcount, 3) + t = next(i) + self.assertEqual(t.i, 2) + self.assertEqual(t.s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) - t = next(i) - self.assertEqual(t.i, 3) - self.assertEqual(t.s, 'baz') - self.assertRaises(StopIteration, next, i) - self.assertEqual(curs.rownumber, 3) - self.assertEqual(curs.rowcount, 3) + t = next(i) + self.assertEqual(t.i, 3) + self.assertEqual(t.s, 'baz') + self.assertRaises(StopIteration, next, i) + self.assertEqual(curs.rownumber, 3) + self.assertEqual(curs.rowcount, 3) def test_record_updated(self): - curs = self.conn.cursor() - curs.execute("select 1 as foo;") - r = curs.fetchone() - self.assertEqual(r.foo, 1) + with self.conn.cursor() as curs: + curs.execute("select 1 as foo;") + r = curs.fetchone() + self.assertEqual(r.foo, 1) - curs.execute("select 2 as bar;") - r = curs.fetchone() - self.assertEqual(r.bar, 2) - self.assertRaises(AttributeError, getattr, r, 'foo') + curs.execute("select 2 as bar;") + r = curs.fetchone() + self.assertEqual(r.bar, 2) + self.assertRaises(AttributeError, getattr, r, 'foo') def test_no_result_no_surprise(self): - curs = self.conn.cursor() - curs.execute("update nttest set s = s") - self.assertRaises(psycopg2.ProgrammingError, curs.fetchone) + with self.conn.cursor() as curs: + curs.execute("update nttest set s = s") + self.assertRaises(psycopg2.ProgrammingError, curs.fetchone) - curs.execute("update nttest set s = s") - self.assertRaises(psycopg2.ProgrammingError, curs.fetchall) + curs.execute("update nttest set s = s") + self.assertRaises(psycopg2.ProgrammingError, curs.fetchall) def test_bad_col_names(self): - curs = self.conn.cursor() - curs.execute('select 1 as "foo.bar_baz", 2 as "?column?", 3 as "3"') - rv = curs.fetchone() - self.assertEqual(rv.foo_bar_baz, 1) - self.assertEqual(rv.f_column_, 2) - self.assertEqual(rv.f3, 3) + with self.conn.cursor() as curs: + curs.execute('select 1 as "foo.bar_baz", 2 as "?column?", 3 as "3"') + rv = curs.fetchone() + self.assertEqual(rv.foo_bar_baz, 1) + self.assertEqual(rv.f_column_, 2) + self.assertEqual(rv.f3, 3) @skip_before_python(3) @skip_before_postgres(8) def test_nonascii_name(self): - curs = self.conn.cursor() - curs.execute('select 1 as \xe5h\xe9') - rv = curs.fetchone() - self.assertEqual(getattr(rv, '\xe5h\xe9'), 1) + with self.conn.cursor() as curs: + curs.execute('select 1 as \xe5h\xe9') + rv = curs.fetchone() + self.assertEqual(getattr(rv, '\xe5h\xe9'), 1) def test_minimal_generation(self): # Instrument the class to verify it gets called the minimum number of times. @@ -543,91 +567,92 @@ class NamedTupleCursorTest(ConnectingTestCase): NamedTupleCursor._make_nt = f_patched try: - curs = self.conn.cursor() - curs.execute("select * from nttest order by 1") - curs.fetchone() - curs.fetchone() - curs.fetchone() - self.assertEqual(1, calls[0]) + with self.conn.cursor() as curs: + curs.execute("select * from nttest order by 1") + curs.fetchone() + curs.fetchone() + curs.fetchone() + self.assertEqual(1, calls[0]) - curs.execute("select * from nttest order by 1") - curs.fetchone() - curs.fetchall() - self.assertEqual(2, calls[0]) + curs.execute("select * from nttest order by 1") + curs.fetchone() + curs.fetchall() + self.assertEqual(2, calls[0]) - curs.execute("select * from nttest order by 1") - curs.fetchone() - curs.fetchmany(1) - self.assertEqual(3, calls[0]) + curs.execute("select * from nttest order by 1") + curs.fetchone() + curs.fetchmany(1) + self.assertEqual(3, calls[0]) finally: NamedTupleCursor._make_nt = f_orig @skip_before_postgres(8, 0) def test_named(self): - curs = self.conn.cursor('tmp') - curs.execute("""select i from generate_series(0,9) i""") - recs = [] - recs.extend(curs.fetchmany(5)) - recs.append(curs.fetchone()) - recs.extend(curs.fetchall()) - self.assertEqual(list(range(10)), [t.i for t in recs]) + with self.conn.cursor('tmp') as curs: + curs.execute("""select i from generate_series(0,9) i""") + recs = [] + recs.extend(curs.fetchmany(5)) + recs.append(curs.fetchone()) + recs.extend(curs.fetchall()) + self.assertEqual(list(range(10)), [t.i for t in recs]) def test_named_fetchone(self): - curs = self.conn.cursor('tmp') - curs.execute("""select 42 as i""") - t = curs.fetchone() - self.assertEqual(t.i, 42) + with self.conn.cursor('tmp') as curs: + curs.execute("""select 42 as i""") + t = curs.fetchone() + self.assertEqual(t.i, 42) def test_named_fetchmany(self): - curs = self.conn.cursor('tmp') - curs.execute("""select 42 as i""") - recs = curs.fetchmany(10) - self.assertEqual(recs[0].i, 42) + with self.conn.cursor('tmp') as curs: + curs.execute("""select 42 as i""") + recs = curs.fetchmany(10) + self.assertEqual(recs[0].i, 42) def test_named_fetchall(self): - curs = self.conn.cursor('tmp') - curs.execute("""select 42 as i""") - recs = curs.fetchall() - self.assertEqual(recs[0].i, 42) + with self.conn.cursor('tmp') as curs: + curs.execute("""select 42 as i""") + recs = curs.fetchall() + self.assertEqual(recs[0].i, 42) @skip_before_postgres(8, 2) def test_not_greedy(self): - curs = self.conn.cursor('tmp') - curs.itersize = 2 - curs.execute("""select clock_timestamp() as ts from generate_series(1,3)""") - recs = [] - for t in curs: - time.sleep(0.01) - recs.append(t) + with self.conn.cursor('tmp') as curs: + curs.itersize = 2 + curs.execute( + """select clock_timestamp() as ts from generate_series(1,3)""") + recs = [] + for t in curs: + time.sleep(0.01) + recs.append(t) - # check that the dataset was not fetched in a single gulp - self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005)) - self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099)) + # check that the dataset was not fetched in a single gulp + self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005)) + self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099)) @skip_before_postgres(8, 0) def test_named_rownumber(self): - curs = self.conn.cursor('tmp') - # Only checking for dataset < itersize: - # see CursorTests.test_iter_named_cursor_rownumber - curs.itersize = 4 - curs.execute("""select * from generate_series(1,3)""") - for i, t in enumerate(curs): - self.assertEqual(i + 1, curs.rownumber) + with self.conn.cursor('tmp') as curs: + # Only checking for dataset < itersize: + # see CursorTests.test_iter_named_cursor_rownumber + curs.itersize = 4 + curs.execute("""select * from generate_series(1,3)""") + for i, t in enumerate(curs): + self.assertEqual(i + 1, curs.rownumber) def test_cache(self): NamedTupleCursor._cached_make_nt.cache_clear() - curs = self.conn.cursor() - curs.execute("select 10 as a, 20 as b") - r1 = curs.fetchone() - curs.execute("select 10 as a, 20 as c") - r2 = curs.fetchone() + with self.conn.cursor() as curs: + curs.execute("select 10 as a, 20 as b") + r1 = curs.fetchone() + curs.execute("select 10 as a, 20 as c") + r2 = curs.fetchone() # Get a new cursor to check that the cache works across multiple ones - curs = self.conn.cursor() - curs.execute("select 10 as a, 30 as b") - r3 = curs.fetchone() + with self.conn.cursor() as curs: + curs.execute("select 10 as a, 30 as b") + r3 = curs.fetchone() self.assert_(type(r1) is type(r3)) self.assert_(type(r1) is not type(r2)) @@ -643,20 +668,20 @@ class NamedTupleCursorTest(ConnectingTestCase): lru_cache(8)(NamedTupleCursor._cached_make_nt.__wrapped__) try: recs = [] - curs = self.conn.cursor() - for i in range(10): - curs.execute("select 1 as f%s" % i) - recs.append(curs.fetchone()) + with self.conn.cursor() as curs: + for i in range(10): + curs.execute("select 1 as f%s" % i) + recs.append(curs.fetchone()) - # Still in cache - curs.execute("select 1 as f9") - rec = curs.fetchone() - self.assert_(any(type(r) is type(rec) for r in recs)) + # Still in cache + curs.execute("select 1 as f9") + rec = curs.fetchone() + self.assert_(any(type(r) is type(rec) for r in recs)) - # Gone from cache - curs.execute("select 1 as f0") - rec = curs.fetchone() - self.assert_(all(type(r) is not type(rec) for r in recs)) + # Gone from cache + curs.execute("select 1 as f0") + rec = curs.fetchone() + self.assert_(all(type(r) is not type(rec) for r in recs)) finally: NamedTupleCursor._cached_make_nt = old_func diff --git a/tests/test_fast_executemany.py b/tests/test_fast_executemany.py index eaba029c..8746303b 100755 --- a/tests/test_fast_executemany.py +++ b/tests/test_fast_executemany.py @@ -46,219 +46,219 @@ class TestPaginate(unittest.TestCase): class FastExecuteTestMixin(object): def setUp(self): super(FastExecuteTestMixin, self).setUp() - cur = self.conn.cursor() - cur.execute("""create table testfast ( - id serial primary key, date date, val int, data text)""") + with self.conn.cursor() as cur: + cur.execute("""create table testfast ( + id serial primary key, date date, val int, data text)""") class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase): def test_empty(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, val) values (%s, %s)", - []) - cur.execute("select * from testfast order by id") - self.assertEqual(cur.fetchall(), []) + with self.conn.cursor() as cur: + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, val) values (%s, %s)", + []) + cur.execute("select * from testfast order by id") + self.assertEqual(cur.fetchall(), []) def test_one(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, val) values (%s, %s)", - iter([(1, 10)])) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(1, 10)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, val) values (%s, %s)", + iter([(1, 10)])) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(1, 10)]) def test_tuples(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, date, val) values (%s, %s, %s)", - ((i, date(2017, 1, i + 1), i * 10) for i in range(10))) - cur.execute("select id, date, val from testfast order by id") - self.assertEqual(cur.fetchall(), - [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, date, val) values (%s, %s, %s)", + ((i, date(2017, 1, i + 1), i * 10) for i in range(10))) + cur.execute("select id, date, val from testfast order by id") + self.assertEqual(cur.fetchall(), + [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) def test_many(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, val) values (%s, %s)", - ((i, i * 10) for i in range(1000))) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, val) values (%s, %s)", + ((i, i * 10) for i in range(1000))) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) def test_composed(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - sql.SQL("insert into {0} (id, val) values (%s, %s)") - .format(sql.Identifier('testfast')), - ((i, i * 10) for i in range(1000))) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_batch(cur, + sql.SQL("insert into {0} (id, val) values (%s, %s)") + .format(sql.Identifier('testfast')), + ((i, i * 10) for i in range(1000))) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) def test_pages(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, val) values (%s, %s)", - ((i, i * 10) for i in range(25)), - page_size=10) + with self.conn.cursor() as cur: + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, val) values (%s, %s)", + ((i, i * 10) for i in range(25)), + page_size=10) - # last command was 5 statements - self.assertEqual(sum(c == u';' for c in cur.query.decode('ascii')), 4) + # last command was 5 statements + self.assertEqual(sum(c == u';' for c in cur.query.decode('ascii')), 4) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) @testutils.skip_before_postgres(8, 0) def test_unicode(self): - cur = self.conn.cursor() - ext.register_type(ext.UNICODE, cur) - snowman = u"\u2603" + with self.conn.cursor() as cur: + ext.register_type(ext.UNICODE, cur) + snowman = u"\u2603" - # unicode in statement - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, - [(1, 'x')]) - cur.execute("select id, data from testfast where id = 1") - self.assertEqual(cur.fetchone(), (1, 'x')) + # unicode in statement + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, + [(1, 'x')]) + cur.execute("select id, data from testfast where id = 1") + self.assertEqual(cur.fetchone(), (1, 'x')) - # unicode in data - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, data) values (%s, %s)", - [(2, snowman)]) - cur.execute("select id, data from testfast where id = 2") - self.assertEqual(cur.fetchone(), (2, snowman)) + # unicode in data + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, data) values (%s, %s)", + [(2, snowman)]) + cur.execute("select id, data from testfast where id = 2") + self.assertEqual(cur.fetchone(), (2, snowman)) - # unicode in both - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, - [(3, snowman)]) - cur.execute("select id, data from testfast where id = 3") - self.assertEqual(cur.fetchone(), (3, snowman)) + # unicode in both + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, + [(3, snowman)]) + cur.execute("select id, data from testfast where id = 3") + self.assertEqual(cur.fetchone(), (3, snowman)) @testutils.skip_before_postgres(8, 2) class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase): def test_empty(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s", - []) - cur.execute("select * from testfast order by id") - self.assertEqual(cur.fetchall(), []) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s", + []) + cur.execute("select * from testfast order by id") + self.assertEqual(cur.fetchall(), []) def test_one(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s", - iter([(1, 10)])) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(1, 10)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s", + iter([(1, 10)])) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(1, 10)]) def test_tuples(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, date, val) values %s", - ((i, date(2017, 1, i + 1), i * 10) for i in range(10))) - cur.execute("select id, date, val from testfast order by id") - self.assertEqual(cur.fetchall(), - [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, date, val) values %s", + ((i, date(2017, 1, i + 1), i * 10) for i in range(10))) + cur.execute("select id, date, val from testfast order by id") + self.assertEqual(cur.fetchall(), + [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) def test_dicts(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, date, val) values %s", - (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar") - for i in range(10)), - template='(%(id)s, %(date)s, %(val)s)') - cur.execute("select id, date, val from testfast order by id") - self.assertEqual(cur.fetchall(), - [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, date, val) values %s", + (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar") + for i in range(10)), + template='(%(id)s, %(date)s, %(val)s)') + cur.execute("select id, date, val from testfast order by id") + self.assertEqual(cur.fetchall(), + [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) def test_many(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s", - ((i, i * 10) for i in range(1000))) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s", + ((i, i * 10) for i in range(1000))) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) def test_composed(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - sql.SQL("insert into {0} (id, val) values %s") - .format(sql.Identifier('testfast')), - ((i, i * 10) for i in range(1000))) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + sql.SQL("insert into {0} (id, val) values %s") + .format(sql.Identifier('testfast')), + ((i, i * 10) for i in range(1000))) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) def test_pages(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s", - ((i, i * 10) for i in range(25)), - page_size=10) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s", + ((i, i * 10) for i in range(25)), + page_size=10) - # last statement was 5 tuples (one parens is for the fields list) - self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6) + # last statement was 5 tuples (one parens is for the fields list) + self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) def test_unicode(self): - cur = self.conn.cursor() - ext.register_type(ext.UNICODE, cur) - snowman = u"\u2603" + with self.conn.cursor() as cur: + ext.register_type(ext.UNICODE, cur) + snowman = u"\u2603" - # unicode in statement - psycopg2.extras.execute_values(cur, - "insert into testfast (id, data) values %%s -- %s" % snowman, - [(1, 'x')]) - cur.execute("select id, data from testfast where id = 1") - self.assertEqual(cur.fetchone(), (1, 'x')) + # unicode in statement + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %%s -- %s" % snowman, + [(1, 'x')]) + cur.execute("select id, data from testfast where id = 1") + self.assertEqual(cur.fetchone(), (1, 'x')) - # unicode in data - psycopg2.extras.execute_values(cur, - "insert into testfast (id, data) values %s", - [(2, snowman)]) - cur.execute("select id, data from testfast where id = 2") - self.assertEqual(cur.fetchone(), (2, snowman)) + # unicode in data + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %s", + [(2, snowman)]) + cur.execute("select id, data from testfast where id = 2") + self.assertEqual(cur.fetchone(), (2, snowman)) - # unicode in both - psycopg2.extras.execute_values(cur, - "insert into testfast (id, data) values %%s -- %s" % snowman, - [(3, snowman)]) - cur.execute("select id, data from testfast where id = 3") - self.assertEqual(cur.fetchone(), (3, snowman)) + # unicode in both + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %%s -- %s" % snowman, + [(3, snowman)]) + cur.execute("select id, data from testfast where id = 3") + self.assertEqual(cur.fetchone(), (3, snowman)) def test_returning(self): - cur = self.conn.cursor() - result = psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s returning id", - ((i, i * 10) for i in range(25)), - page_size=10, fetch=True) - # result contains all returned pages - self.assertEqual([r[0] for r in result], list(range(25))) + with self.conn.cursor() as cur: + result = psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s returning id", + ((i, i * 10) for i in range(25)), + page_size=10, fetch=True) + # result contains all returned pages + self.assertEqual([r[0] for r in result], list(range(25))) def test_invalid_sql(self): - cur = self.conn.cursor() - self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, - "insert", []) - self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, - "insert %s and %s", []) - self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, - "insert %f", []) - self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, - "insert %f %s", []) + with self.conn.cursor() as cur: + self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, + "insert", []) + self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, + "insert %s and %s", []) + self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, + "insert %f", []) + self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, + "insert %f %s", []) def test_percent_escape(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, data) values %s -- a%%b", - [(1, 'hi')]) - self.assert_(b'a%%b' not in cur.query) - self.assert_(b'a%b' in cur.query) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %s -- a%%b", + [(1, 'hi')]) + self.assert_(b'a%%b' not in cur.query) + self.assert_(b'a%b' in cur.query) - cur.execute("select id, data from testfast") - self.assertEqual(cur.fetchall(), [(1, 'hi')]) + cur.execute("select id, data from testfast") + self.assertEqual(cur.fetchall(), [(1, 'hi')]) def test_suite(): diff --git a/tests/test_green.py b/tests/test_green.py index e56ce586..635a1907 100755 --- a/tests/test_green.py +++ b/tests/test_green.py @@ -71,14 +71,14 @@ class GreenTestCase(ConnectingTestCase): # a very large query requires a flush loop to be sent to the backend conn = self.conn stub = self.set_stub_wait_callback(conn) - curs = conn.cursor() - for mb in 1, 5, 10, 20, 50: - size = mb * 1024 * 1024 - del stub.polls[:] - curs.execute("select %s;", ('x' * size,)) - self.assertEqual(size, len(curs.fetchone()[0])) - if stub.polls.count(psycopg2.extensions.POLL_WRITE) > 1: - return + with conn.cursor() as curs: + for mb in 1, 5, 10, 20, 50: + size = mb * 1024 * 1024 + del stub.polls[:] + curs.execute("select %s;", ('x' * size,)) + self.assertEqual(size, len(curs.fetchone()[0])) + if stub.polls.count(psycopg2.extensions.POLL_WRITE) > 1: + return # This is more a testing glitch than an error: it happens # on high load on linux: probably because the kernel has more @@ -105,21 +105,21 @@ class GreenTestCase(ConnectingTestCase): # if there is an error in a green query, don't freak out and close # the connection conn = self.conn - curs = conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, - curs.execute, "select the unselectable") + with conn.cursor() as curs: + self.assertRaises(psycopg2.ProgrammingError, + curs.execute, "select the unselectable") - # check that the connection is left in an usable state - self.assert_(not conn.closed) - conn.rollback() - curs.execute("select 1") - self.assertEqual(curs.fetchone()[0], 1) + # check that the connection is left in an usable state + self.assert_(not conn.closed) + conn.rollback() + curs.execute("select 1") + self.assertEqual(curs.fetchone()[0], 1) @skip_before_postgres(8, 2) def test_copy_no_hang(self): - cur = self.conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, - cur.execute, "copy (select 1) to stdout") + with self.conn.cursor() as cur: + self.assertRaises(psycopg2.ProgrammingError, + cur.execute, "copy (select 1) to stdout") @slow @skip_before_postgres(9, 0) @@ -137,19 +137,19 @@ class GreenTestCase(ConnectingTestCase): raise conn.OperationalError("bad state from poll: %s" % state) stub = self.set_stub_wait_callback(self.conn, wait) - cur = self.conn.cursor() - cur.execute(""" - select 1; - do $$ - begin - raise notice 'hello'; - end - $$ language plpgsql; - select pg_sleep(1); - """) + with self.conn.cursor() as cur: + cur.execute(""" + select 1; + do $$ + begin + raise notice 'hello'; + end + $$ language plpgsql; + select pg_sleep(1); + """) - polls = stub.polls.count(POLL_READ) - self.assert_(polls > 8, polls) + polls = stub.polls.count(POLL_READ) + self.assert_(polls > 8, polls) class CallbackErrorTestCase(ConnectingTestCase): @@ -203,16 +203,16 @@ class CallbackErrorTestCase(ConnectingTestCase): for i in range(100): self.to_error = None cnn = self.connect() - cur = cnn.cursor() - self.to_error = i - try: - cur.execute("select 1") - cur.fetchone() - except ZeroDivisionError: - pass - else: - # The query completed - return + with cnn.cursor() as cur: + self.to_error = i + try: + cur.execute("select 1") + cur.fetchone() + except ZeroDivisionError: + pass + else: + # The query completed + return self.fail("you should have had a success or an error by now") @@ -220,16 +220,19 @@ class CallbackErrorTestCase(ConnectingTestCase): for i in range(100): self.to_error = None cnn = self.connect() - cur = cnn.cursor('foo') - self.to_error = i - try: - cur.execute("select 1") - cur.fetchone() - except ZeroDivisionError: - pass - else: - # The query completed - return + with cnn.cursor('foo') as cur: + self.to_error = i + try: + cur.execute("select 1") + cur.fetchone() + except ZeroDivisionError: + pass + else: + # The query completed + return + finally: + # Don't raise an exception in the cursor context manager. + self.to_error = None self.fail("you should have had a success or an error by now") diff --git a/tests/test_ipaddress.py b/tests/test_ipaddress.py index ccbae291..ff92a711 100755 --- a/tests/test_ipaddress.py +++ b/tests/test_ipaddress.py @@ -33,82 +33,82 @@ except ImportError: @unittest.skipIf(ip is None, "'ipaddress' module not available") class NetworkingTestCase(testutils.ConnectingTestCase): def test_inet_cast(self): - cur = self.conn.cursor() - psycopg2.extras.register_ipaddress(cur) + with self.conn.cursor() as cur: + psycopg2.extras.register_ipaddress(cur) - cur.execute("select null::inet") - self.assert_(cur.fetchone()[0] is None) + cur.execute("select null::inet") + self.assert_(cur.fetchone()[0] is None) - cur.execute("select '127.0.0.1/24'::inet") - obj = cur.fetchone()[0] - self.assert_(isinstance(obj, ip.IPv4Interface), repr(obj)) - self.assertEquals(obj, ip.ip_interface('127.0.0.1/24')) + cur.execute("select '127.0.0.1/24'::inet") + obj = cur.fetchone()[0] + self.assert_(isinstance(obj, ip.IPv4Interface), repr(obj)) + self.assertEquals(obj, ip.ip_interface('127.0.0.1/24')) - cur.execute("select '::ffff:102:300/128'::inet") - obj = cur.fetchone()[0] - self.assert_(isinstance(obj, ip.IPv6Interface), repr(obj)) - self.assertEquals(obj, ip.ip_interface('::ffff:102:300/128')) + cur.execute("select '::ffff:102:300/128'::inet") + obj = cur.fetchone()[0] + self.assert_(isinstance(obj, ip.IPv6Interface), repr(obj)) + self.assertEquals(obj, ip.ip_interface('::ffff:102:300/128')) @testutils.skip_before_postgres(8, 2) def test_inet_array_cast(self): - cur = self.conn.cursor() - psycopg2.extras.register_ipaddress(cur) - cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::inet[]") - l = cur.fetchone()[0] - self.assert_(l[0] is None) - self.assertEquals(l[1], ip.ip_interface('127.0.0.1')) - self.assertEquals(l[2], ip.ip_interface('::ffff:102:300/128')) - self.assert_(isinstance(l[1], ip.IPv4Interface), l) - self.assert_(isinstance(l[2], ip.IPv6Interface), l) + with self.conn.cursor() as cur: + psycopg2.extras.register_ipaddress(cur) + cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::inet[]") + l = cur.fetchone()[0] + self.assert_(l[0] is None) + self.assertEquals(l[1], ip.ip_interface('127.0.0.1')) + self.assertEquals(l[2], ip.ip_interface('::ffff:102:300/128')) + self.assert_(isinstance(l[1], ip.IPv4Interface), l) + self.assert_(isinstance(l[2], ip.IPv6Interface), l) def test_inet_adapt(self): - cur = self.conn.cursor() - psycopg2.extras.register_ipaddress(cur) + with self.conn.cursor() as cur: + psycopg2.extras.register_ipaddress(cur) - cur.execute("select %s", [ip.ip_interface('127.0.0.1/24')]) - self.assertEquals(cur.fetchone()[0], '127.0.0.1/24') + cur.execute("select %s", [ip.ip_interface('127.0.0.1/24')]) + self.assertEquals(cur.fetchone()[0], '127.0.0.1/24') - cur.execute("select %s", [ip.ip_interface('::ffff:102:300/128')]) - self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128') + cur.execute("select %s", [ip.ip_interface('::ffff:102:300/128')]) + self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128') def test_cidr_cast(self): - cur = self.conn.cursor() - psycopg2.extras.register_ipaddress(cur) + with self.conn.cursor() as cur: + psycopg2.extras.register_ipaddress(cur) - cur.execute("select null::cidr") - self.assert_(cur.fetchone()[0] is None) + cur.execute("select null::cidr") + self.assert_(cur.fetchone()[0] is None) - cur.execute("select '127.0.0.0/24'::cidr") - obj = cur.fetchone()[0] - self.assert_(isinstance(obj, ip.IPv4Network), repr(obj)) - self.assertEquals(obj, ip.ip_network('127.0.0.0/24')) + cur.execute("select '127.0.0.0/24'::cidr") + obj = cur.fetchone()[0] + self.assert_(isinstance(obj, ip.IPv4Network), repr(obj)) + self.assertEquals(obj, ip.ip_network('127.0.0.0/24')) - cur.execute("select '::ffff:102:300/128'::cidr") - obj = cur.fetchone()[0] - self.assert_(isinstance(obj, ip.IPv6Network), repr(obj)) - self.assertEquals(obj, ip.ip_network('::ffff:102:300/128')) + cur.execute("select '::ffff:102:300/128'::cidr") + obj = cur.fetchone()[0] + self.assert_(isinstance(obj, ip.IPv6Network), repr(obj)) + self.assertEquals(obj, ip.ip_network('::ffff:102:300/128')) @testutils.skip_before_postgres(8, 2) def test_cidr_array_cast(self): - cur = self.conn.cursor() - psycopg2.extras.register_ipaddress(cur) - cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::cidr[]") - l = cur.fetchone()[0] - self.assert_(l[0] is None) - self.assertEquals(l[1], ip.ip_network('127.0.0.1')) - self.assertEquals(l[2], ip.ip_network('::ffff:102:300/128')) - self.assert_(isinstance(l[1], ip.IPv4Network), l) - self.assert_(isinstance(l[2], ip.IPv6Network), l) + with self.conn.cursor() as cur: + psycopg2.extras.register_ipaddress(cur) + cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::cidr[]") + l = cur.fetchone()[0] + self.assert_(l[0] is None) + self.assertEquals(l[1], ip.ip_network('127.0.0.1')) + self.assertEquals(l[2], ip.ip_network('::ffff:102:300/128')) + self.assert_(isinstance(l[1], ip.IPv4Network), l) + self.assert_(isinstance(l[2], ip.IPv6Network), l) def test_cidr_adapt(self): - cur = self.conn.cursor() - psycopg2.extras.register_ipaddress(cur) + with self.conn.cursor() as cur: + psycopg2.extras.register_ipaddress(cur) - cur.execute("select %s", [ip.ip_network('127.0.0.0/24')]) - self.assertEquals(cur.fetchone()[0], '127.0.0.0/24') + cur.execute("select %s", [ip.ip_network('127.0.0.0/24')]) + self.assertEquals(cur.fetchone()[0], '127.0.0.0/24') - cur.execute("select %s", [ip.ip_network('::ffff:102:300/128')]) - self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128') + cur.execute("select %s", [ip.ip_network('::ffff:102:300/128')]) + self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128') def test_suite(): diff --git a/tests/test_module.py b/tests/test_module.py index 416e6237..85153fdf 100755 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -154,22 +154,22 @@ class ConnectTestCase(unittest.TestCase): class ExceptionsTestCase(ConnectingTestCase): def test_attributes(self): - cur = self.conn.cursor() - try: - cur.execute("select * from nonexist") - except psycopg2.Error as exc: - e = exc + with self.conn.cursor() as cur: + try: + cur.execute("select * from nonexist") + except psycopg2.Error as exc: + e = exc self.assertEqual(e.pgcode, '42P01') self.assert_(e.pgerror) self.assert_(e.cursor is cur) def test_diagnostics_attributes(self): - cur = self.conn.cursor() - try: - cur.execute("select * from nonexist") - except psycopg2.Error as exc: - e = exc + with self.conn.cursor() as cur: + try: + cur.execute("select * from nonexist") + except psycopg2.Error as exc: + e = exc diag = e.diag self.assert_(isinstance(diag, psycopg2.extensions.Diagnostics)) @@ -195,11 +195,11 @@ class ExceptionsTestCase(ConnectingTestCase): def test_diagnostics_life(self): def tmp(): - cur = self.conn.cursor() - try: - cur.execute("select * from nonexist") - except psycopg2.Error as exc: - return cur, exc + with self.conn.cursor() as cur: + try: + cur.execute("select * from nonexist") + except psycopg2.Error as exc: + return cur, exc cur, e = tmp() diag = e.diag diff --git a/tests/test_notify.py b/tests/test_notify.py index 51865e64..ce3e0b03 100755 --- a/tests/test_notify.py +++ b/tests/test_notify.py @@ -120,10 +120,11 @@ conn.close() self.listen('foo') pid = int(self.notify('foo').communicate()[0]) self.assertEqual(0, len(self.conn.notifies)) - self.conn.cursor().execute('select 1;') - self.assertEqual(1, len(self.conn.notifies)) - self.assertEqual(pid, self.conn.notifies[0][0]) - self.assertEqual('foo', self.conn.notifies[0][1]) + with self.conn.cursor() as cur: + cur.execute('select 1;') + self.assertEqual(1, len(self.conn.notifies)) + self.assertEqual(pid, self.conn.notifies[0][0]) + self.assertEqual('foo', self.conn.notifies[0][1]) @slow def test_notify_object(self): diff --git a/tests/test_quote.py b/tests/test_quote.py index 42e90eb7..7d3759a8 100755 --- a/tests/test_quote.py +++ b/tests/test_quote.py @@ -56,24 +56,24 @@ class QuotingTestCase(ConnectingTestCase): """ data += "".join(map(chr, range(1, 127))) - curs = self.conn.cursor() - curs.execute("SELECT %s;", (data,)) - res = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute("SELECT %s;", (data,)) + res = curs.fetchone()[0] self.assertEqual(res, data) self.assert_(not self.conn.notices) def test_string_null_terminator(self): - curs = self.conn.cursor() - data = 'abcd\x01\x00cdefg' + with self.conn.cursor() as curs: + data = 'abcd\x01\x00cdefg' - try: - curs.execute("SELECT %s", (data,)) - except ValueError as e: - self.assertEquals(str(e), - 'A string literal cannot contain NUL (0x00) characters.') - else: - self.fail("ValueError not raised") + try: + curs.execute("SELECT %s", (data,)) + except ValueError as e: + self.assertEquals(str(e), + 'A string literal cannot contain NUL (0x00) characters.') + else: + self.fail("ValueError not raised") def test_binary(self): data = b"""some data with \000\013 binary @@ -84,12 +84,12 @@ class QuotingTestCase(ConnectingTestCase): else: data += bytes(list(range(256))) - curs = self.conn.cursor() - curs.execute("SELECT %s::bytea;", (psycopg2.Binary(data),)) - if PY2: - res = str(curs.fetchone()[0]) - else: - res = curs.fetchone()[0].tobytes() + with self.conn.cursor() as curs: + curs.execute("SELECT %s::bytea;", (psycopg2.Binary(data),)) + if PY2: + res = str(curs.fetchone()[0]) + else: + res = curs.fetchone()[0].tobytes() if res[0] in (b'x', ord(b'x')) and self.conn.info.server_version >= 90000: return self.skipTest( @@ -99,86 +99,87 @@ class QuotingTestCase(ConnectingTestCase): self.assert_(not self.conn.notices) def test_unicode(self): - curs = self.conn.cursor() - curs.execute("SHOW server_encoding") - server_encoding = curs.fetchone()[0] - if server_encoding != "UTF8": - return self.skipTest( - "Unicode test skipped since server encoding is %s" - % server_encoding) + with self.conn.cursor() as curs: + curs.execute("SHOW server_encoding") + server_encoding = curs.fetchone()[0] + if server_encoding != "UTF8": + return self.skipTest( + "Unicode test skipped since server encoding is %s" + % server_encoding) - data = u"""some data with \t chars - to escape into, 'quotes', \u20ac euro sign and \\ a backslash too. - """ - data += u"".join(map(unichr, [u for u in range(1, 65536) - if not 0xD800 <= u <= 0xDFFF])) # surrogate area - self.conn.set_client_encoding('UNICODE') + data = u"""some data with \t chars + to escape into, 'quotes', \u20ac euro sign and \\ a backslash too. + """ + data += u"".join(map(unichr, [u for u in range(1, 65536) + if not 0xD800 <= u <= 0xDFFF])) # surrogate area + self.conn.set_client_encoding('UNICODE') - psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn) - curs.execute("SELECT %s::text;", (data,)) - res = curs.fetchone()[0] + psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn) + curs.execute("SELECT %s::text;", (data,)) + res = curs.fetchone()[0] - self.assertEqual(res, data) - self.assert_(not self.conn.notices) + self.assertEqual(res, data) + self.assert_(not self.conn.notices) def test_latin1(self): self.conn.set_client_encoding('LATIN1') - curs = self.conn.cursor() - if PY2: - data = ''.join(map(chr, range(32, 127) + range(160, 256))) - else: - data = bytes(list(range(32, 127)) - + list(range(160, 256))).decode('latin1') - - # as string - curs.execute("SELECT %s::text;", (data,)) - res = curs.fetchone()[0] - self.assertEqual(res, data) - self.assert_(not self.conn.notices) - - # as unicode - if PY2: - psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn) - data = data.decode('latin1') + with self.conn.cursor() as curs: + if PY2: + data = ''.join(map(chr, range(32, 127) + range(160, 256))) + else: + data = bytes(list(range(32, 127)) + + list(range(160, 256))).decode('latin1') + # as string curs.execute("SELECT %s::text;", (data,)) res = curs.fetchone()[0] self.assertEqual(res, data) self.assert_(not self.conn.notices) + # as unicode + if PY2: + psycopg2.extensions.register_type( + psycopg2.extensions.UNICODE, self.conn) + data = data.decode('latin1') + + curs.execute("SELECT %s::text;", (data,)) + res = curs.fetchone()[0] + self.assertEqual(res, data) + self.assert_(not self.conn.notices) + def test_koi8(self): self.conn.set_client_encoding('KOI8') - curs = self.conn.cursor() - if PY2: - data = ''.join(map(chr, range(32, 127) + range(128, 256))) - else: - data = bytes(list(range(32, 127)) - + list(range(128, 256))).decode('koi8_r') - - # as string - curs.execute("SELECT %s::text;", (data,)) - res = curs.fetchone()[0] - self.assertEqual(res, data) - self.assert_(not self.conn.notices) - - # as unicode - if PY2: - psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn) - data = data.decode('koi8_r') + with self.conn.cursor() as curs: + if PY2: + data = ''.join(map(chr, range(32, 127) + range(128, 256))) + else: + data = bytes(list(range(32, 127)) + + list(range(128, 256))).decode('koi8_r') + # as string curs.execute("SELECT %s::text;", (data,)) res = curs.fetchone()[0] self.assertEqual(res, data) self.assert_(not self.conn.notices) + # as unicode + if PY2: + psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn) + data = data.decode('koi8_r') + + curs.execute("SELECT %s::text;", (data,)) + res = curs.fetchone()[0] + self.assertEqual(res, data) + self.assert_(not self.conn.notices) + def test_bytes(self): snowman = u"\u2603" conn = self.connect() conn.set_client_encoding('UNICODE') psycopg2.extensions.register_type(psycopg2.extensions.BYTES, conn) - curs = conn.cursor() - curs.execute("select %s::text", (snowman,)) - x = curs.fetchone()[0] + with conn.cursor() as curs: + curs.execute("select %s::text", (snowman,)) + x = curs.fetchone()[0] self.assert_(isinstance(x, bytes)) self.assertEqual(x, snowman.encode('utf8')) diff --git a/tests/test_replication.py b/tests/test_replication.py index 3ed68a57..e5564eb2 100755 --- a/tests/test_replication.py +++ b/tests/test_replication.py @@ -73,14 +73,13 @@ class ReplicationTestCase(ConnectingTestCase): conn = self.connect() if conn is None: return - cur = conn.cursor() - - try: - cur.execute("DROP TABLE dummy1") - except psycopg2.ProgrammingError: - conn.rollback() - cur.execute( - "CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id") + with conn.cursor() as cur: + try: + cur.execute("DROP TABLE dummy1") + except psycopg2.ProgrammingError: + conn.rollback() + cur.execute( + "CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id") conn.commit() @@ -90,9 +89,9 @@ class ReplicationTest(ReplicationTestCase): conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) if conn is None: return - cur = conn.cursor() - cur.execute("IDENTIFY_SYSTEM") - cur.fetchall() + with conn.cursor() as cur: + cur.execute("IDENTIFY_SYSTEM") + cur.fetchall() @skip_before_postgres(9, 0) def test_datestyle(self): @@ -104,29 +103,28 @@ class ReplicationTest(ReplicationTestCase): connection_factory=PhysicalReplicationConnection) if conn is None: return - cur = conn.cursor() - cur.execute("IDENTIFY_SYSTEM") - cur.fetchall() + with conn.cursor() as cur: + cur.execute("IDENTIFY_SYSTEM") + cur.fetchall() @skip_before_postgres(9, 4) def test_logical_replication_connection(self): conn = self.repl_connect(connection_factory=LogicalReplicationConnection) if conn is None: return - cur = conn.cursor() - cur.execute("IDENTIFY_SYSTEM") - cur.fetchall() + with conn.cursor() as cur: + cur.execute("IDENTIFY_SYSTEM") + cur.fetchall() @skip_before_postgres(9, 4) # slots require 9.4 def test_create_replication_slot(self): conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) if conn is None: return - cur = conn.cursor() - - self.create_replication_slot(cur) - self.assertRaises( - psycopg2.ProgrammingError, self.create_replication_slot, cur) + with conn.cursor() as cur: + self.create_replication_slot(cur) + self.assertRaises( + psycopg2.ProgrammingError, self.create_replication_slot, cur) @skip_before_postgres(9, 4) # slots require 9.4 @skip_repl_if_green @@ -134,13 +132,12 @@ class ReplicationTest(ReplicationTestCase): conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) if conn is None: return - cur = conn.cursor() + with conn.cursor() as cur: + self.assertRaises(psycopg2.ProgrammingError, + cur.start_replication, self.slot) - self.assertRaises(psycopg2.ProgrammingError, - cur.start_replication, self.slot) - - self.create_replication_slot(cur) - cur.start_replication(self.slot) + self.create_replication_slot(cur) + cur.start_replication(self.slot) @skip_before_postgres(9, 4) # slots require 9.4 @skip_repl_if_green @@ -148,12 +145,11 @@ class ReplicationTest(ReplicationTestCase): conn = self.repl_connect(connection_factory=LogicalReplicationConnection) if conn is None: return - cur = conn.cursor() - - self.create_replication_slot(cur, output_plugin='test_decoding') - cur.start_replication_expert( - sql.SQL("START_REPLICATION SLOT {slot} LOGICAL 0/00000000").format( - slot=sql.Identifier(self.slot))) + with conn.cursor() as cur: + self.create_replication_slot(cur, output_plugin='test_decoding') + cur.start_replication_expert( + sql.SQL("START_REPLICATION SLOT {slot} LOGICAL 0/00000000").format( + slot=sql.Identifier(self.slot))) @skip_before_postgres(9, 4) # slots require 9.4 @skip_repl_if_green @@ -161,23 +157,22 @@ class ReplicationTest(ReplicationTestCase): conn = self.repl_connect(connection_factory=LogicalReplicationConnection) if conn is None: return - cur = conn.cursor() + with conn.cursor() as cur: + self.create_replication_slot(cur, output_plugin='test_decoding') + self.make_replication_events() - self.create_replication_slot(cur, output_plugin='test_decoding') - self.make_replication_events() + def consume(msg): + raise StopReplication() - def consume(msg): - raise StopReplication() + with self.assertRaises(psycopg2.DataError): + # try with invalid options + cur.start_replication( + slot_name=self.slot, options={'invalid_param': 'value'}) + cur.consume_stream(consume) - with self.assertRaises(psycopg2.DataError): - # try with invalid options - cur.start_replication( - slot_name=self.slot, options={'invalid_param': 'value'}) - cur.consume_stream(consume) - - # try with correct command - cur.start_replication(slot_name=self.slot) - self.assertRaises(StopReplication, cur.consume_stream, consume) + # try with correct command + cur.start_replication(slot_name=self.slot) + self.assertRaises(StopReplication, cur.consume_stream, consume) @skip_before_postgres(9, 4) # slots require 9.4 @skip_repl_if_green @@ -208,17 +203,16 @@ class ReplicationTest(ReplicationTestCase): conn = self.repl_connect(connection_factory=LogicalReplicationConnection) if conn is None: return - cur = conn.cursor() + with conn.cursor() as cur: + self.create_replication_slot(cur, output_plugin='test_decoding') - self.create_replication_slot(cur, output_plugin='test_decoding') + self.make_replication_events() - self.make_replication_events() + cur.start_replication(self.slot) - cur.start_replication(self.slot) - - def consume(msg): - raise StopReplication() - self.assertRaises(StopReplication, cur.consume_stream, consume) + def consume(msg): + raise StopReplication() + self.assertRaises(StopReplication, cur.consume_stream, consume) class AsyncReplicationTest(ReplicationTestCase): @@ -230,42 +224,41 @@ class AsyncReplicationTest(ReplicationTestCase): if conn is None: return - cur = conn.cursor() + with conn.cursor() as cur: + self.create_replication_slot(cur, output_plugin='test_decoding') + self.wait(cur) - self.create_replication_slot(cur, output_plugin='test_decoding') - self.wait(cur) + cur.start_replication(self.slot) + self.wait(cur) - cur.start_replication(self.slot) - self.wait(cur) + self.make_replication_events() - self.make_replication_events() + self.msg_count = 0 - self.msg_count = 0 + def consume(msg): + # just check the methods + "%s: %s" % (cur.io_timestamp, repr(msg)) + "%s: %s" % (cur.feedback_timestamp, repr(msg)) + "%s: %s" % (cur.wal_end, repr(msg)) - def consume(msg): - # just check the methods - "%s: %s" % (cur.io_timestamp, repr(msg)) - "%s: %s" % (cur.feedback_timestamp, repr(msg)) - "%s: %s" % (cur.wal_end, repr(msg)) + self.msg_count += 1 + if self.msg_count > 3: + cur.send_feedback(reply=True) + raise StopReplication() - self.msg_count += 1 - if self.msg_count > 3: - cur.send_feedback(reply=True) - raise StopReplication() + cur.send_feedback(flush_lsn=msg.data_start) - cur.send_feedback(flush_lsn=msg.data_start) + # cannot be used in asynchronous mode + self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume) - # cannot be used in asynchronous mode - self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume) - - def process_stream(): - while True: - msg = cur.read_message() - if msg: - consume(msg) - else: - select([cur], [], []) - self.assertRaises(StopReplication, process_stream) + def process_stream(): + while True: + msg = cur.read_message() + if msg: + consume(msg) + else: + select([cur], [], []) + self.assertRaises(StopReplication, process_stream) def test_suite(): diff --git a/tests/test_sql.py b/tests/test_sql.py index 9089ae77..53ab952d 100755 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -117,61 +117,61 @@ class SqlFormatTests(ConnectingTestCase): sql.SQL("select {0};").format(sql.Literal(Foo())).as_string, self.conn) def test_execute(self): - cur = self.conn.cursor() - cur.execute(""" - create table test_compose ( - id serial primary key, - foo text, bar text, "ba'z" text) - """) - cur.execute( - sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( - sql.Identifier('test_compose'), - sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), - (sql.Placeholder() * 3).join(', ')), - (10, 'a', 'b', 'c')) + with self.conn.cursor() as cur: + cur.execute(""" + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """) + cur.execute( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier('test_compose'), + sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), + (sql.Placeholder() * 3).join(', ')), + (10, 'a', 'b', 'c')) - cur.execute("select * from test_compose") - self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')]) + cur.execute("select * from test_compose") + self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')]) def test_executemany(self): - cur = self.conn.cursor() - cur.execute(""" - create table test_compose ( - id serial primary key, - foo text, bar text, "ba'z" text) - """) - cur.executemany( - sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( - sql.Identifier('test_compose'), - sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), - (sql.Placeholder() * 3).join(', ')), - [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) + with self.conn.cursor() as cur: + cur.execute(""" + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """) + cur.executemany( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier('test_compose'), + sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), + (sql.Placeholder() * 3).join(', ')), + [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) - cur.execute("select * from test_compose") - self.assertEqual(cur.fetchall(), - [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) + cur.execute("select * from test_compose") + self.assertEqual(cur.fetchall(), + [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) @skip_copy_if_green @skip_before_postgres(8, 2) def test_copy(self): - cur = self.conn.cursor() - cur.execute(""" - create table test_compose ( - id serial primary key, - foo text, bar text, "ba'z" text) - """) + with self.conn.cursor() as cur: + cur.execute(""" + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """) - s = StringIO("10\ta\tb\tc\n20\td\te\tf\n") - cur.copy_expert( - sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format( - t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s) + s = StringIO("10\ta\tb\tc\n20\td\te\tf\n") + cur.copy_expert( + sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format( + t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s) - s1 = StringIO() - cur.copy_expert( - sql.SQL("copy (select {f} from {t} order by id) to stdout").format( - t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s1) - s1.seek(0) - self.assertEqual(s1.read(), 'c\nf\n') + s1 = StringIO() + cur.copy_expert( + sql.SQL("copy (select {f} from {t} order by id) to stdout").format( + t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s1) + s1.seek(0) + self.assertEqual(s1.read(), 'c\nf\n') class IdentifierTests(ConnectingTestCase): diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 25feba0b..ef4cb6bc 100755 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -37,59 +37,59 @@ class TransactionTests(ConnectingTestCase): def setUp(self): ConnectingTestCase.setUp(self) self.conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE) - curs = self.conn.cursor() - curs.execute(''' - CREATE TEMPORARY TABLE table1 ( - id int PRIMARY KEY - )''') - # The constraint is set to deferrable for the commit_failed test - curs.execute(''' - CREATE TEMPORARY TABLE table2 ( - id int PRIMARY KEY, - table1_id int, - CONSTRAINT table2__table1_id__fk - FOREIGN KEY (table1_id) REFERENCES table1(id) DEFERRABLE)''') - curs.execute('INSERT INTO table1 VALUES (1)') - curs.execute('INSERT INTO table2 VALUES (1, 1)') + with self.conn.cursor() as curs: + curs.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') + # The constraint is set to deferrable for the commit_failed test + curs.execute(''' + CREATE TEMPORARY TABLE table2 ( + id int PRIMARY KEY, + table1_id int, + CONSTRAINT table2__table1_id__fk + FOREIGN KEY (table1_id) REFERENCES table1(id) DEFERRABLE)''') + curs.execute('INSERT INTO table1 VALUES (1)') + curs.execute('INSERT INTO table2 VALUES (1, 1)') self.conn.commit() def test_rollback(self): # Test that rollback undoes changes - curs = self.conn.cursor() - curs.execute('INSERT INTO table2 VALUES (2, 1)') - # Rollback takes us from BEGIN state to READY state - self.assertEqual(self.conn.status, STATUS_BEGIN) - self.conn.rollback() - self.assertEqual(self.conn.status, STATUS_READY) - curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2') - self.assertEqual(curs.fetchall(), []) + with self.conn.cursor() as curs: + curs.execute('INSERT INTO table2 VALUES (2, 1)') + # Rollback takes us from BEGIN state to READY state + self.assertEqual(self.conn.status, STATUS_BEGIN) + self.conn.rollback() + self.assertEqual(self.conn.status, STATUS_READY) + curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2') + self.assertEqual(curs.fetchall(), []) def test_commit(self): # Test that commit stores changes - curs = self.conn.cursor() - curs.execute('INSERT INTO table2 VALUES (2, 1)') - # Rollback takes us from BEGIN state to READY state - self.assertEqual(self.conn.status, STATUS_BEGIN) - self.conn.commit() - self.assertEqual(self.conn.status, STATUS_READY) - # Now rollback and show that the new record is still there: - self.conn.rollback() - curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2') - self.assertEqual(curs.fetchall(), [(2, 1)]) + with self.conn.cursor() as curs: + curs.execute('INSERT INTO table2 VALUES (2, 1)') + # Rollback takes us from BEGIN state to READY state + self.assertEqual(self.conn.status, STATUS_BEGIN) + self.conn.commit() + self.assertEqual(self.conn.status, STATUS_READY) + # Now rollback and show that the new record is still there: + self.conn.rollback() + curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2') + self.assertEqual(curs.fetchall(), [(2, 1)]) def test_failed_commit(self): # Test that we can recover from a failed commit. # We use a deferred constraint to cause a failure on commit. - curs = self.conn.cursor() - curs.execute('SET CONSTRAINTS table2__table1_id__fk DEFERRED') - curs.execute('INSERT INTO table2 VALUES (2, 42)') - # The commit should fail, and move the cursor back to READY state - self.assertEqual(self.conn.status, STATUS_BEGIN) - self.assertRaises(psycopg2.IntegrityError, self.conn.commit) - self.assertEqual(self.conn.status, STATUS_READY) - # The connection should be ready to use for the next transaction: - curs.execute('SELECT 1') - self.assertEqual(curs.fetchone()[0], 1) + with self.conn.cursor() as curs: + curs.execute('SET CONSTRAINTS table2__table1_id__fk DEFERRED') + curs.execute('INSERT INTO table2 VALUES (2, 42)') + # The commit should fail, and move the cursor back to READY state + self.assertEqual(self.conn.status, STATUS_BEGIN) + self.assertRaises(psycopg2.IntegrityError, self.conn.commit) + self.assertEqual(self.conn.status, STATUS_READY) + # The connection should be ready to use for the next transaction: + curs.execute('SELECT 1') + self.assertEqual(curs.fetchone()[0], 1) class DeadlockSerializationTests(ConnectingTestCase): @@ -103,32 +103,32 @@ class DeadlockSerializationTests(ConnectingTestCase): def setUp(self): ConnectingTestCase.setUp(self) - curs = self.conn.cursor() - # Drop table if it already exists - try: - curs.execute("DROP TABLE table1") - self.conn.commit() - except psycopg2.DatabaseError: - self.conn.rollback() - try: - curs.execute("DROP TABLE table2") - self.conn.commit() - except psycopg2.DatabaseError: - self.conn.rollback() - # Create sample data - curs.execute(""" - CREATE TABLE table1 ( - id int PRIMARY KEY, - name text) - """) - curs.execute("INSERT INTO table1 VALUES (1, 'hello')") - curs.execute("CREATE TABLE table2 (id int PRIMARY KEY)") + with self.conn.cursor() as curs: + # Drop table if it already exists + try: + curs.execute("DROP TABLE table1") + self.conn.commit() + except psycopg2.DatabaseError: + self.conn.rollback() + try: + curs.execute("DROP TABLE table2") + self.conn.commit() + except psycopg2.DatabaseError: + self.conn.rollback() + # Create sample data + curs.execute(""" + CREATE TABLE table1 ( + id int PRIMARY KEY, + name text) + """) + curs.execute("INSERT INTO table1 VALUES (1, 'hello')") + curs.execute("CREATE TABLE table2 (id int PRIMARY KEY)") self.conn.commit() def tearDown(self): - curs = self.conn.cursor() - curs.execute("DROP TABLE table1") - curs.execute("DROP TABLE table2") + with self.conn.cursor() as curs: + curs.execute("DROP TABLE table1") + curs.execute("DROP TABLE table2") self.conn.commit() ConnectingTestCase.tearDown(self) @@ -142,11 +142,11 @@ class DeadlockSerializationTests(ConnectingTestCase): def task1(): try: conn = self.connect() - curs = conn.cursor() - curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE") - step1.set() - step2.wait() - curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE") + with conn.cursor() as curs: + curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE") + step1.set() + step2.wait() + curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE") except psycopg2.DatabaseError as exc: self.thread1_error = exc step1.set() @@ -155,11 +155,11 @@ class DeadlockSerializationTests(ConnectingTestCase): def task2(): try: conn = self.connect() - curs = conn.cursor() - step1.wait() - curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE") - step2.set() - curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE") + with conn.cursor() as curs: + step1.wait() + curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE") + step2.set() + curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE") except psycopg2.DatabaseError as exc: self.thread2_error = exc step2.set() @@ -190,12 +190,12 @@ class DeadlockSerializationTests(ConnectingTestCase): def task1(): try: conn = self.connect() - curs = conn.cursor() - curs.execute("SELECT name FROM table1 WHERE id = 1") - curs.fetchall() - step1.set() - step2.wait() - curs.execute("UPDATE table1 SET name='task1' WHERE id = 1") + with conn.cursor() as curs: + curs.execute("SELECT name FROM table1 WHERE id = 1") + curs.fetchall() + step1.set() + step2.wait() + curs.execute("UPDATE table1 SET name='task1' WHERE id = 1") conn.commit() except psycopg2.DatabaseError as exc: self.thread1_error = exc @@ -205,9 +205,9 @@ class DeadlockSerializationTests(ConnectingTestCase): def task2(): try: conn = self.connect() - curs = conn.cursor() - step1.wait() - curs.execute("UPDATE table1 SET name='task2' WHERE id = 1") + with conn.cursor() as curs: + step1.wait() + curs.execute("UPDATE table1 SET name='task2' WHERE id = 1") conn.commit() except psycopg2.DatabaseError as exc: self.thread2_error = exc @@ -240,11 +240,11 @@ class QueryCancellationTests(ConnectingTestCase): @skip_before_postgres(8, 2) def test_statement_timeout(self): - curs = self.conn.cursor() - # Set a low statement timeout, then sleep for a longer period. - curs.execute('SET statement_timeout TO 10') - self.assertRaises(psycopg2.extensions.QueryCanceledError, - curs.execute, 'SELECT pg_sleep(50)') + with self.conn.cursor() as curs: + # Set a low statement timeout, then sleep for a longer period. + curs.execute('SET statement_timeout TO 10') + self.assertRaises(psycopg2.extensions.QueryCanceledError, + curs.execute, 'SELECT pg_sleep(50)') def test_suite(): diff --git a/tests/test_types_basic.py b/tests/test_types_basic.py index a6b7af03..84d22597 100755 --- a/tests/test_types_basic.py +++ b/tests/test_types_basic.py @@ -41,9 +41,9 @@ class TypesBasicTests(ConnectingTestCase): """Test that all type conversions are working.""" def execute(self, *args): - curs = self.conn.cursor() - curs.execute(*args) - return curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute(*args) + return curs.fetchone()[0] def testQuoting(self): s = "Quote'this\\! ''ok?''" @@ -156,26 +156,27 @@ class TypesBasicTests(ConnectingTestCase): def testEmptyArrayRegression(self): # ticket #42 - curs = self.conn.cursor() - curs.execute( - "create table array_test " - "(id integer, col timestamp without time zone[])") + with self.conn.cursor() as curs: + curs.execute( + "create table array_test " + "(id integer, col timestamp without time zone[])") - curs.execute("insert into array_test values (%s, %s)", - (1, [datetime.date(2011, 2, 14)])) - curs.execute("select col from array_test where id = 1") - self.assertEqual(curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)]) + curs.execute("insert into array_test values (%s, %s)", + (1, [datetime.date(2011, 2, 14)])) + curs.execute("select col from array_test where id = 1") + self.assertEqual( + curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)]) - curs.execute("insert into array_test values (%s, %s)", (2, [])) - curs.execute("select col from array_test where id = 2") - self.assertEqual(curs.fetchone()[0], []) + curs.execute("insert into array_test values (%s, %s)", (2, [])) + curs.execute("select col from array_test where id = 2") + self.assertEqual(curs.fetchone()[0], []) @testutils.skip_before_postgres(8, 4) def testNestedEmptyArray(self): # issue #788 - curs = self.conn.cursor() - curs.execute("select 10 = any(%s::int[])", ([[]], )) - self.assertFalse(curs.fetchone()[0]) + with self.conn.cursor() as curs: + curs.execute("select 10 = any(%s::int[])", ([[]], )) + self.assertFalse(curs.fetchone()[0]) def testEmptyArrayNoCast(self): s = self.execute("SELECT '{}' AS foo") @@ -204,86 +205,86 @@ class TypesBasicTests(ConnectingTestCase): self.failUnlessEqual(ss, r) def testArrayMalformed(self): - curs = self.conn.cursor() - ss = ['', '{', '{}}', '{' * 20 + '}' * 20] - for s in ss: - self.assertRaises(psycopg2.DataError, - psycopg2.extensions.STRINGARRAY, s.encode('utf8'), curs) + with self.conn.cursor() as curs: + ss = ['', '{', '{}}', '{' * 20 + '}' * 20] + for s in ss: + self.assertRaises(psycopg2.DataError, + psycopg2.extensions.STRINGARRAY, s.encode('utf8'), curs) def testTextArray(self): - curs = self.conn.cursor() - curs.execute("select '{a,b,c}'::text[]") - x = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute("select '{a,b,c}'::text[]") + x = curs.fetchone()[0] self.assert_(isinstance(x[0], str)) self.assertEqual(x, ['a', 'b', 'c']) def testUnicodeArray(self): psycopg2.extensions.register_type( psycopg2.extensions.UNICODEARRAY, self.conn) - curs = self.conn.cursor() - curs.execute("select '{a,b,c}'::text[]") - x = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute("select '{a,b,c}'::text[]") + x = curs.fetchone()[0] self.assert_(isinstance(x[0], text_type)) self.assertEqual(x, [u'a', u'b', u'c']) def testBytesArray(self): psycopg2.extensions.register_type( psycopg2.extensions.BYTESARRAY, self.conn) - curs = self.conn.cursor() - curs.execute("select '{a,b,c}'::text[]") - x = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute("select '{a,b,c}'::text[]") + x = curs.fetchone()[0] self.assert_(isinstance(x[0], bytes)) self.assertEqual(x, [b'a', b'b', b'c']) @testutils.skip_before_postgres(8, 2) def testArrayOfNulls(self): - curs = self.conn.cursor() - curs.execute(""" - create table na ( - texta text[], - inta int[], - boola boolean[], + with self.conn.cursor() as curs: + curs.execute(""" + create table na ( + texta text[], + inta int[], + boola boolean[], - textaa text[][], - intaa int[][], - boolaa boolean[][] - )""") + textaa text[][], + intaa int[][], + boolaa boolean[][] + )""") - curs.execute("insert into na (texta) values (%s)", ([None],)) - curs.execute("insert into na (texta) values (%s)", (['a', None],)) - curs.execute("insert into na (texta) values (%s)", ([None, None],)) - curs.execute("insert into na (inta) values (%s)", ([None],)) - curs.execute("insert into na (inta) values (%s)", ([42, None],)) - curs.execute("insert into na (inta) values (%s)", ([None, None],)) - curs.execute("insert into na (boola) values (%s)", ([None],)) - curs.execute("insert into na (boola) values (%s)", ([True, None],)) - curs.execute("insert into na (boola) values (%s)", ([None, None],)) + curs.execute("insert into na (texta) values (%s)", ([None],)) + curs.execute("insert into na (texta) values (%s)", (['a', None],)) + curs.execute("insert into na (texta) values (%s)", ([None, None],)) + curs.execute("insert into na (inta) values (%s)", ([None],)) + curs.execute("insert into na (inta) values (%s)", ([42, None],)) + curs.execute("insert into na (inta) values (%s)", ([None, None],)) + curs.execute("insert into na (boola) values (%s)", ([None],)) + curs.execute("insert into na (boola) values (%s)", ([True, None],)) + curs.execute("insert into na (boola) values (%s)", ([None, None],)) - curs.execute("insert into na (textaa) values (%s)", ([[None]],)) - curs.execute("insert into na (textaa) values (%s)", ([['a', None]],)) - curs.execute("insert into na (textaa) values (%s)", ([[None, None]],)) + curs.execute("insert into na (textaa) values (%s)", ([[None]],)) + curs.execute("insert into na (textaa) values (%s)", ([['a', None]],)) + curs.execute("insert into na (textaa) values (%s)", ([[None, None]],)) - curs.execute("insert into na (intaa) values (%s)", ([[None]],)) - curs.execute("insert into na (intaa) values (%s)", ([[42, None]],)) - curs.execute("insert into na (intaa) values (%s)", ([[None, None]],)) + curs.execute("insert into na (intaa) values (%s)", ([[None]],)) + curs.execute("insert into na (intaa) values (%s)", ([[42, None]],)) + curs.execute("insert into na (intaa) values (%s)", ([[None, None]],)) - curs.execute("insert into na (boolaa) values (%s)", ([[None]],)) - curs.execute("insert into na (boolaa) values (%s)", ([[True, None]],)) - curs.execute("insert into na (boolaa) values (%s)", ([[None, None]],)) + curs.execute("insert into na (boolaa) values (%s)", ([[None]],)) + curs.execute("insert into na (boolaa) values (%s)", ([[True, None]],)) + curs.execute("insert into na (boolaa) values (%s)", ([[None, None]],)) @testutils.skip_before_postgres(8, 2) def testNestedArrays(self): - curs = self.conn.cursor() - for a in [ - [[1]], - [[None]], - [[None, None, None]], - [[None, None], [1, None]], - [[None, None], [None, None]], - [[[None, None], [None, None]]], - ]: - curs.execute("select %s::int[]", (a,)) - self.assertEqual(curs.fetchone()[0], a) + with self.conn.cursor() as curs: + for a in [ + [[1]], + [[None]], + [[None, None, None]], + [[None, None], [1, None]], + [[None, None], [None, None]], + [[[None, None], [None, None]]], + ]: + curs.execute("select %s::int[]", (a,)) + self.assertEqual(curs.fetchone()[0], a) @testutils.skip_from_python(3) def testTypeRoundtripBuffer(self): diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py index e4828077..081d12f2 100755 --- a/tests/test_types_extras.py +++ b/tests/test_types_extras.py @@ -45,9 +45,9 @@ class TypesExtrasTests(ConnectingTestCase): """Test that all type conversions are working.""" def execute(self, *args): - curs = self.conn.cursor() - curs.execute(*args) - return curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute(*args) + return curs.fetchone()[0] @skip_if_no_uuid def testUUID(self): @@ -231,19 +231,19 @@ class HstoreTestCase(ConnectingTestCase): @skip_if_no_hstore def test_register_conn(self): register_hstore(self.conn) - cur = self.conn.cursor() - cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") - t = cur.fetchone() + with self.conn.cursor() as cur: + cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + t = cur.fetchone() self.assert_(t[0] is None) self.assertEqual(t[1], {}) self.assertEqual(t[2], {'a': 'b'}) @skip_if_no_hstore def test_register_curs(self): - cur = self.conn.cursor() - register_hstore(cur) - cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") - t = cur.fetchone() + with self.conn.cursor() as cur: + register_hstore(cur) + cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + t = cur.fetchone() self.assert_(t[0] is None) self.assertEqual(t[1], {}) self.assertEqual(t[2], {'a': 'b'}) @@ -252,9 +252,9 @@ class HstoreTestCase(ConnectingTestCase): @skip_from_python(3) def test_register_unicode(self): register_hstore(self.conn, unicode=True) - cur = self.conn.cursor() - cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") - t = cur.fetchone() + with self.conn.cursor() as cur: + cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + t = cur.fetchone() self.assert_(t[0] is None) self.assertEqual(t[1], {}) self.assertEqual(t[2], {u'a': u'b'}) @@ -268,9 +268,9 @@ class HstoreTestCase(ConnectingTestCase): register_hstore(self.conn, globally=True) conn2 = self.connect() try: - cur2 = self.conn.cursor() - cur2.execute("select 'a => b'::hstore") - r = cur2.fetchone() + with self.conn.cursor() as cur2: + cur2.execute("select 'a => b'::hstore") + r = cur2.fetchone() self.assert_(isinstance(r[0], dict)) finally: conn2.close() @@ -278,67 +278,66 @@ class HstoreTestCase(ConnectingTestCase): @skip_if_no_hstore def test_roundtrip(self): register_hstore(self.conn) - cur = self.conn.cursor() + with self.conn.cursor() as cur: + def ok(d): + cur.execute("select %s", (d,)) + d1 = cur.fetchone()[0] + self.assertEqual(len(d), len(d1)) + for k in d: + self.assert_(k in d1, k) + self.assertEqual(d[k], d1[k]) - def ok(d): - cur.execute("select %s", (d,)) - d1 = cur.fetchone()[0] - self.assertEqual(len(d), len(d1)) - for k in d: - self.assert_(k in d1, k) - self.assertEqual(d[k], d1[k]) + ok({}) + ok({'a': 'b', 'c': None}) - ok({}) - ok({'a': 'b', 'c': None}) + ab = list(map(chr, range(32, 128))) + ok(dict(zip(ab, ab))) + ok({''.join(ab): ''.join(ab)}) - ab = list(map(chr, range(32, 128))) - ok(dict(zip(ab, ab))) - ok({''.join(ab): ''.join(ab)}) + self.conn.set_client_encoding('latin1') + if PY2: + ab = map(chr, range(32, 127) + range(160, 255)) + else: + ab = bytes( + list(range(32, 127)) + list(range(160, 255))).decode('latin1') - self.conn.set_client_encoding('latin1') - if PY2: - ab = map(chr, range(32, 127) + range(160, 255)) - else: - ab = bytes(list(range(32, 127)) + list(range(160, 255))).decode('latin1') - - ok({''.join(ab): ''.join(ab)}) - ok(dict(zip(ab, ab))) + ok({''.join(ab): ''.join(ab)}) + ok(dict(zip(ab, ab))) @skip_if_no_hstore @skip_from_python(3) def test_roundtrip_unicode(self): register_hstore(self.conn, unicode=True) - cur = self.conn.cursor() + with self.conn.cursor() as cur: + def ok(d): + cur.execute("select %s", (d,)) + d1 = cur.fetchone()[0] + self.assertEqual(len(d), len(d1)) + for k, v in d1.iteritems(): + self.assert_(k in d, k) + self.assertEqual(d[k], v) + self.assert_(isinstance(k, unicode)) + self.assert_(v is None or isinstance(v, unicode)) - def ok(d): - cur.execute("select %s", (d,)) - d1 = cur.fetchone()[0] - self.assertEqual(len(d), len(d1)) - for k, v in d1.iteritems(): - self.assert_(k in d, k) - self.assertEqual(d[k], v) - self.assert_(isinstance(k, unicode)) - self.assert_(v is None or isinstance(v, unicode)) + ok({}) + ok({'a': 'b', 'c': None, 'd': u'\u20ac', u'\u2603': 'e'}) - ok({}) - ok({'a': 'b', 'c': None, 'd': u'\u20ac', u'\u2603': 'e'}) - - ab = map(unichr, range(1, 1024)) - ok({u''.join(ab): u''.join(ab)}) - ok(dict(zip(ab, ab))) + ab = map(unichr, range(1, 1024)) + ok({u''.join(ab): u''.join(ab)}) + ok(dict(zip(ab, ab))) @skip_if_no_hstore @restore_types def test_oid(self): - cur = self.conn.cursor() - cur.execute("select 'hstore'::regtype::oid") - oid = cur.fetchone()[0] + with self.conn.cursor() as cur: + cur.execute("select 'hstore'::regtype::oid") + oid = cur.fetchone()[0] - # Note: None as conn_or_cursor is just for testing: not public - # interface and it may break in future. - register_hstore(None, globally=True, oid=oid) - cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") - t = cur.fetchone() + # Note: None as conn_or_cursor is just for testing: not public + # interface and it may break in future. + register_hstore(None, globally=True, oid=oid) + cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + t = cur.fetchone() self.assert_(t[0] is None) self.assertEqual(t[1], {}) self.assertEqual(t[2], {'a': 'b'}) @@ -363,32 +362,32 @@ class HstoreTestCase(ConnectingTestCase): ds.append({''.join(ab): ''.join(ab)}) ds.append(dict(zip(ab, ab))) - cur = self.conn.cursor() - cur.execute("select %s", (ds,)) - ds1 = cur.fetchone()[0] + with self.conn.cursor() as cur: + cur.execute("select %s", (ds,)) + ds1 = cur.fetchone()[0] self.assertEqual(ds, ds1) @skip_if_no_hstore @skip_before_postgres(8, 3) def test_array_cast(self): register_hstore(self.conn) - cur = self.conn.cursor() - cur.execute("select array['a=>1'::hstore, 'b=>2'::hstore];") - a = cur.fetchone()[0] + with self.conn.cursor() as cur: + cur.execute("select array['a=>1'::hstore, 'b=>2'::hstore];") + a = cur.fetchone()[0] self.assertEqual(a, [{'a': '1'}, {'b': '2'}]) @skip_if_no_hstore @restore_types def test_array_cast_oid(self): - cur = self.conn.cursor() - cur.execute("select 'hstore'::regtype::oid, 'hstore[]'::regtype::oid") - oid, aoid = cur.fetchone() + with self.conn.cursor() as cur: + cur.execute("select 'hstore'::regtype::oid, 'hstore[]'::regtype::oid") + oid, aoid = cur.fetchone() - register_hstore(None, globally=True, oid=oid, array_oid=aoid) - cur.execute(""" - select null::hstore, ''::hstore, - 'a => b'::hstore, '{a=>b}'::hstore[]""") - t = cur.fetchone() + register_hstore(None, globally=True, oid=oid, array_oid=aoid) + cur.execute(""" + select null::hstore, ''::hstore, + 'a => b'::hstore, '{a=>b}'::hstore[]""") + t = cur.fetchone() self.assert_(t[0] is None) self.assertEqual(t[1], {}) self.assertEqual(t[2], {'a': 'b'}) @@ -399,18 +398,18 @@ class HstoreTestCase(ConnectingTestCase): conn = self.connect(connection_factory=RealDictConnection) try: register_hstore(conn) - curs = conn.cursor() - curs.execute("select ''::hstore as x") - self.assertEqual(curs.fetchone()['x'], {}) + with conn.cursor() as curs: + curs.execute("select ''::hstore as x") + self.assertEqual(curs.fetchone()['x'], {}) finally: conn.close() conn = self.connect(connection_factory=RealDictConnection) try: - curs = conn.cursor() - register_hstore(curs) - curs.execute("select ''::hstore as x") - self.assertEqual(curs.fetchone()['x'], {}) + with conn.cursor() as curs: + register_hstore(curs) + curs.execute("select ''::hstore as x") + self.assertEqual(curs.fetchone()['x'], {}) finally: conn.close() @@ -431,12 +430,12 @@ def skip_if_no_composite(f): class AdaptTypeTestCase(ConnectingTestCase): @skip_if_no_composite def test_none_in_record(self): - curs = self.conn.cursor() - s = curs.mogrify("SELECT %s;", [(42, None)]) - self.assertEqual(b"SELECT (42, NULL);", s) - curs.execute("SELECT %s;", [(42, None)]) - d = curs.fetchone()[0] - self.assertEqual("(42,)", d) + with self.conn.cursor() as curs: + s = curs.mogrify("SELECT %s;", [(42, None)]) + self.assertEqual(b"SELECT (42, NULL);", s) + curs.execute("SELECT %s;", [(42, None)]) + d = curs.fetchone()[0] + self.assertEqual("(42,)", d) def test_none_fast_path(self): # the None adapter is not actually invoked in regular adaptation @@ -448,18 +447,17 @@ class AdaptTypeTestCase(ConnectingTestCase): def getquoted(self): return "NOPE!" - curs = self.conn.cursor() + with self.conn.cursor() as curs: + orig_adapter = ext.adapters[type(None), ext.ISQLQuote] + try: + ext.register_adapter(type(None), WonkyAdapter) + self.assertEqual(ext.adapt(None).getquoted(), "NOPE!") - orig_adapter = ext.adapters[type(None), ext.ISQLQuote] - try: - ext.register_adapter(type(None), WonkyAdapter) - self.assertEqual(ext.adapt(None).getquoted(), "NOPE!") + s = curs.mogrify("SELECT %s;", (None,)) + self.assertEqual(b"SELECT NULL;", s) - s = curs.mogrify("SELECT %s;", (None,)) - self.assertEqual(b"SELECT NULL;", s) - - finally: - ext.register_adapter(type(None), orig_adapter) + finally: + ext.register_adapter(type(None), orig_adapter) def test_tokenization(self): def ok(s, v): @@ -502,10 +500,10 @@ class AdaptTypeTestCase(ConnectingTestCase): self.assertEqual(t.attnames, ['anint', 'astring', 'adate']) self.assertEqual(t.atttypes, [23, 25, 1082]) - curs = self.conn.cursor() - r = (10, 'hello', date(2011, 1, 2)) - curs.execute("select %s::type_isd;", (r,)) - v = curs.fetchone()[0] + with self.conn.cursor() as curs: + r = (10, 'hello', date(2011, 1, 2)) + curs.execute("select %s::type_isd;", (r,)) + v = curs.fetchone()[0] self.assert_(isinstance(v, t.type)) self.assertEqual(v[0], 10) self.assertEqual(v[1], "hello") @@ -519,21 +517,21 @@ class AdaptTypeTestCase(ConnectingTestCase): def test_empty_string(self): # issue #141 self._create_type("type_ss", [('s1', 'text'), ('s2', 'text')]) - curs = self.conn.cursor() - psycopg2.extras.register_composite("type_ss", curs) + with self.conn.cursor() as curs: + psycopg2.extras.register_composite("type_ss", curs) - def ok(t): - curs.execute("select %s::type_ss", (t,)) - rv = curs.fetchone()[0] - self.assertEqual(t, rv) + def ok(t): + curs.execute("select %s::type_ss", (t,)) + rv = curs.fetchone()[0] + self.assertEqual(t, rv) - ok(('a', 'b')) - ok(('a', '')) - ok(('', 'b')) - ok(('a', None)) - ok((None, 'b')) - ok(('', '')) - ok((None, None)) + ok(('a', 'b')) + ok(('a', '')) + ok(('', 'b')) + ok(('a', None)) + ok((None, 'b')) + ok(('', '')) + ok((None, None)) @skip_if_no_composite def test_cast_nested(self): @@ -548,10 +546,10 @@ class AdaptTypeTestCase(ConnectingTestCase): psycopg2.extras.register_composite("type_r_dt", self.conn) psycopg2.extras.register_composite("type_r_ft", self.conn) - curs = self.conn.cursor() - r = (0.25, (date(2011, 1, 2), (42, "hello"))) - curs.execute("select %s::type_r_ft;", (r,)) - v = curs.fetchone()[0] + with self.conn.cursor() as curs: + r = (0.25, (date(2011, 1, 2), (42, "hello"))) + curs.execute("select %s::type_r_ft;", (r,)) + v = curs.fetchone()[0] self.assertEqual(r, v) self.assertEqual(v.anotherpair.apair.astring, "hello") @@ -560,13 +558,12 @@ class AdaptTypeTestCase(ConnectingTestCase): def test_register_on_cursor(self): self._create_type("type_ii", [("a", "integer"), ("b", "integer")]) - curs1 = self.conn.cursor() - curs2 = self.conn.cursor() - psycopg2.extras.register_composite("type_ii", curs1) - curs1.execute("select (1,2)::type_ii") - self.assertEqual(curs1.fetchone()[0], (1, 2)) - curs2.execute("select (1,2)::type_ii") - self.assertEqual(curs2.fetchone()[0], "(1,2)") + with self.conn.cursor() as curs1, self.conn.cursor() as curs2: + psycopg2.extras.register_composite("type_ii", curs1) + curs1.execute("select (1,2)::type_ii") + self.assertEqual(curs1.fetchone()[0], (1, 2)) + curs2.execute("select (1,2)::type_ii") + self.assertEqual(curs2.fetchone()[0], "(1,2)") @skip_if_no_composite def test_register_on_connection(self): @@ -576,12 +573,11 @@ class AdaptTypeTestCase(ConnectingTestCase): conn2 = self.connect() try: psycopg2.extras.register_composite("type_ii", conn1) - curs1 = conn1.cursor() - curs2 = conn2.cursor() - curs1.execute("select (1,2)::type_ii") - self.assertEqual(curs1.fetchone()[0], (1, 2)) - curs2.execute("select (1,2)::type_ii") - self.assertEqual(curs2.fetchone()[0], "(1,2)") + with conn1.cursor() as curs1, conn2.cursor() as curs2: + curs1.execute("select (1,2)::type_ii") + self.assertEqual(curs1.fetchone()[0], (1, 2)) + curs2.execute("select (1,2)::type_ii") + self.assertEqual(curs2.fetchone()[0], "(1,2)") finally: conn1.close() conn2.close() @@ -595,12 +591,11 @@ class AdaptTypeTestCase(ConnectingTestCase): conn2 = self.connect() try: psycopg2.extras.register_composite("type_ii", conn1, globally=True) - curs1 = conn1.cursor() - curs2 = conn2.cursor() - curs1.execute("select (1,2)::type_ii") - self.assertEqual(curs1.fetchone()[0], (1, 2)) - curs2.execute("select (1,2)::type_ii") - self.assertEqual(curs2.fetchone()[0], (1, 2)) + with conn1.cursor() as curs1, conn2.cursor() as curs2: + curs1.execute("select (1,2)::type_ii") + self.assertEqual(curs1.fetchone()[0], (1, 2)) + curs2.execute("select (1,2)::type_ii") + self.assertEqual(curs2.fetchone()[0], (1, 2)) finally: conn1.close() @@ -608,22 +603,22 @@ class AdaptTypeTestCase(ConnectingTestCase): @skip_if_no_composite def test_composite_namespace(self): - curs = self.conn.cursor() - curs.execute(""" - select nspname from pg_namespace - where nspname = 'typens'; - """) - if not curs.fetchone(): - curs.execute("create schema typens;") - self.conn.commit() + with self.conn.cursor() as curs: + curs.execute(""" + select nspname from pg_namespace + where nspname = 'typens'; + """) + if not curs.fetchone(): + curs.execute("create schema typens;") + self.conn.commit() - self._create_type("typens.typens_ii", - [("a", "integer"), ("b", "integer")]) - t = psycopg2.extras.register_composite( - "typens.typens_ii", self.conn) - self.assertEqual(t.schema, 'typens') - curs.execute("select (4,8)::typens.typens_ii") - self.assertEqual(curs.fetchone()[0], (4, 8)) + self._create_type("typens.typens_ii", + [("a", "integer"), ("b", "integer")]) + t = psycopg2.extras.register_composite( + "typens.typens_ii", self.conn) + self.assertEqual(t.schema, 'typens') + curs.execute("select (4,8)::typens.typens_ii") + self.assertEqual(curs.fetchone()[0], (4, 8)) @skip_if_no_composite @skip_before_postgres(8, 4) @@ -633,11 +628,11 @@ class AdaptTypeTestCase(ConnectingTestCase): t = psycopg2.extras.register_composite("type_isd", self.conn) - curs = self.conn.cursor() - r1 = (10, 'hello', date(2011, 1, 2)) - r2 = (20, 'world', date(2011, 1, 3)) - curs.execute("select %s::type_isd[];", ([r1, r2],)) - v = curs.fetchone()[0] + with self.conn.cursor() as curs: + r1 = (10, 'hello', date(2011, 1, 2)) + r2 = (20, 'world', date(2011, 1, 3)) + curs.execute("select %s::type_isd[];", ([r1, r2],)) + v = curs.fetchone()[0] self.assertEqual(len(v), 2) self.assert_(isinstance(v[0], t.type)) self.assertEqual(v[0][0], 10) @@ -652,56 +647,56 @@ class AdaptTypeTestCase(ConnectingTestCase): def test_wrong_schema(self): oid = self._create_type("type_ii", [("a", "integer"), ("b", "integer")]) c = CompositeCaster('type_ii', oid, [('a', 23), ('b', 23), ('c', 23)]) - curs = self.conn.cursor() - psycopg2.extensions.register_type(c.typecaster, curs) - curs.execute("select (1,2)::type_ii") - self.assertRaises(psycopg2.DataError, curs.fetchone) + with self.conn.cursor() as curs: + psycopg2.extensions.register_type(c.typecaster, curs) + curs.execute("select (1,2)::type_ii") + self.assertRaises(psycopg2.DataError, curs.fetchone) @slow @skip_if_no_composite @skip_before_postgres(8, 4) def test_from_tables(self): - curs = self.conn.cursor() - curs.execute("""create table ctest1 ( - id integer primary key, - temp int, - label varchar - );""") + with self.conn.cursor() as curs: + curs.execute("""create table ctest1 ( + id integer primary key, + temp int, + label varchar + );""") - curs.execute("""alter table ctest1 drop temp;""") + curs.execute("""alter table ctest1 drop temp;""") - curs.execute("""create table ctest2 ( - id serial primary key, - label varchar, - test_id integer references ctest1(id) - );""") + curs.execute("""create table ctest2 ( + id serial primary key, + label varchar, + test_id integer references ctest1(id) + );""") - curs.execute("""insert into ctest1 (id, label) values - (1, 'test1'), - (2, 'test2');""") - curs.execute("""insert into ctest2 (label, test_id) values - ('testa', 1), - ('testb', 1), - ('testc', 2), - ('testd', 2);""") + curs.execute("""insert into ctest1 (id, label) values + (1, 'test1'), + (2, 'test2');""") + curs.execute("""insert into ctest2 (label, test_id) values + ('testa', 1), + ('testb', 1), + ('testc', 2), + ('testd', 2);""") - psycopg2.extras.register_composite("ctest1", curs) - psycopg2.extras.register_composite("ctest2", curs) + psycopg2.extras.register_composite("ctest1", curs) + psycopg2.extras.register_composite("ctest2", curs) - curs.execute(""" - select ctest1, array_agg(ctest2) as test2s - from ( - select ctest1, ctest2 - from ctest1 inner join ctest2 on ctest1.id = ctest2.test_id - order by ctest1.id, ctest2.label - ) x group by ctest1;""") + curs.execute(""" + select ctest1, array_agg(ctest2) as test2s + from ( + select ctest1, ctest2 + from ctest1 inner join ctest2 on ctest1.id = ctest2.test_id + order by ctest1.id, ctest2.label + ) x group by ctest1;""") - r = curs.fetchone() - self.assertEqual(r[0], (1, 'test1')) - self.assertEqual(r[1], [(1, 'testa', 1), (2, 'testb', 1)]) - r = curs.fetchone() - self.assertEqual(r[0], (2, 'test2')) - self.assertEqual(r[1], [(3, 'testc', 2), (4, 'testd', 2)]) + r = curs.fetchone() + self.assertEqual(r[0], (1, 'test1')) + self.assertEqual(r[1], [(1, 'testa', 1), (2, 'testb', 1)]) + r = curs.fetchone() + self.assertEqual(r[0], (2, 'test2')) + self.assertEqual(r[1], [(3, 'testc', 2), (4, 'testd', 2)]) @skip_if_no_composite def test_non_dbapi_connection(self): @@ -710,18 +705,18 @@ class AdaptTypeTestCase(ConnectingTestCase): conn = self.connect(connection_factory=RealDictConnection) try: register_composite('type_ii', conn) - curs = conn.cursor() - curs.execute("select '(1,2)'::type_ii as x") - self.assertEqual(curs.fetchone()['x'], (1, 2)) + with conn.cursor() as curs: + curs.execute("select '(1,2)'::type_ii as x") + self.assertEqual(curs.fetchone()['x'], (1, 2)) finally: conn.close() conn = self.connect(connection_factory=RealDictConnection) try: - curs = conn.cursor() - register_composite('type_ii', conn) - curs.execute("select '(1,2)'::type_ii as x") - self.assertEqual(curs.fetchone()['x'], (1, 2)) + with conn.cursor() as curs: + register_composite('type_ii', conn) + curs.execute("select '(1,2)'::type_ii as x") + self.assertEqual(curs.fetchone()['x'], (1, 2)) finally: conn.close() @@ -739,35 +734,35 @@ class AdaptTypeTestCase(ConnectingTestCase): self.assertEqual(t.name, 'type_isd') self.assertEqual(t.oid, oid) - curs = self.conn.cursor() - r = (10, 'hello', date(2011, 1, 2)) - curs.execute("select %s::type_isd;", (r,)) - v = curs.fetchone()[0] + with self.conn.cursor() as curs: + r = (10, 'hello', date(2011, 1, 2)) + curs.execute("select %s::type_isd;", (r,)) + v = curs.fetchone()[0] self.assert_(isinstance(v, dict)) self.assertEqual(v['anint'], 10) self.assertEqual(v['astring'], "hello") self.assertEqual(v['adate'], date(2011, 1, 2)) def _create_type(self, name, fields): - curs = self.conn.cursor() - try: - curs.execute("drop type %s cascade;" % name) - except psycopg2.ProgrammingError: - self.conn.rollback() + with self.conn.cursor() as curs: + try: + curs.execute("drop type %s cascade;" % name) + except psycopg2.ProgrammingError: + self.conn.rollback() - curs.execute("create type %s as (%s);" % (name, - ", ".join(["%s %s" % p for p in fields]))) - if '.' in name: - schema, name = name.split('.') - else: - schema = 'public' + curs.execute("create type %s as (%s);" % (name, + ", ".join(["%s %s" % p for p in fields]))) + if '.' in name: + schema, name = name.split('.') + else: + schema = 'public' - curs.execute("""\ - SELECT t.oid - FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid - WHERE typname = %s and nspname = %s; - """, (name, schema)) - oid = curs.fetchone()[0] + curs.execute("""\ + SELECT t.oid + FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid + WHERE typname = %s and nspname = %s; + """, (name, schema)) + oid = curs.fetchone()[0] self.conn.commit() return oid @@ -776,10 +771,10 @@ def skip_if_no_json_type(f): """Skip a test if PostgreSQL json type is not available""" @wraps(f) def skip_if_no_json_type_(self): - curs = self.conn.cursor() - curs.execute("select oid from pg_type where typname = 'json'") - if not curs.fetchone(): - return self.skipTest("json not available in test database") + with self.conn.cursor() as curs: + curs.execute("select oid from pg_type where typname = 'json'") + if not curs.fetchone(): + return self.skipTest("json not available in test database") return f(self) @@ -791,10 +786,10 @@ class JsonTestCase(ConnectingTestCase): objs = [None, "te'xt", 123, 123.45, u'\xe0\u20ac', ['a', 100], {'a': 100}] - curs = self.conn.cursor() - for obj in enumerate(objs): - self.assertQuotedEqual(curs.mogrify("%s", (Json(obj),)), - psycopg2.extensions.QuotedString(json.dumps(obj)).getquoted()) + with self.conn.cursor() as curs: + for obj in enumerate(objs): + self.assertQuotedEqual(curs.mogrify("%s", (Json(obj),)), + psycopg2.extensions.QuotedString(json.dumps(obj)).getquoted()) def test_adapt_dumps(self): class DecimalEncoder(json.JSONEncoder): @@ -803,13 +798,13 @@ class JsonTestCase(ConnectingTestCase): return float(obj) return json.JSONEncoder.default(self, obj) - curs = self.conn.cursor() - obj = Decimal('123.45') + with self.conn.cursor() as curs: + obj = Decimal('123.45') - def dumps(obj): - return json.dumps(obj, cls=DecimalEncoder) - self.assertQuotedEqual(curs.mogrify("%s", (Json(obj, dumps=dumps),)), - b"'123.45'") + def dumps(obj): + return json.dumps(obj, cls=DecimalEncoder) + self.assertQuotedEqual(curs.mogrify("%s", (Json(obj, dumps=dumps),)), + b"'123.45'") def test_adapt_subclass(self): class DecimalEncoder(json.JSONEncoder): @@ -822,59 +817,58 @@ class JsonTestCase(ConnectingTestCase): def dumps(self, obj): return json.dumps(obj, cls=DecimalEncoder) - curs = self.conn.cursor() - obj = Decimal('123.45') - self.assertQuotedEqual(curs.mogrify("%s", (MyJson(obj),)), b"'123.45'") + with self.conn.cursor() as curs: + obj = Decimal('123.45') + self.assertQuotedEqual(curs.mogrify("%s", (MyJson(obj),)), b"'123.45'") @restore_types def test_register_on_dict(self): psycopg2.extensions.register_adapter(dict, Json) - curs = self.conn.cursor() - obj = {'a': 123} - self.assertQuotedEqual( - curs.mogrify("%s", (obj,)), b"""'{"a": 123}'""") + with self.conn.cursor() as curs: + obj = {'a': 123} + self.assertQuotedEqual( + curs.mogrify("%s", (obj,)), b"""'{"a": 123}'""") def test_type_not_available(self): - curs = self.conn.cursor() - curs.execute("select oid from pg_type where typname = 'json'") - if curs.fetchone(): - return self.skipTest("json available in test database") + with self.conn.cursor() as curs: + curs.execute("select oid from pg_type where typname = 'json'") + if curs.fetchone(): + return self.skipTest("json available in test database") self.assertRaises(psycopg2.ProgrammingError, psycopg2.extras.register_json, self.conn) @skip_before_postgres(9, 2) def test_default_cast(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) - - curs.execute("""select array['{"a": 100.0, "b": null}']::json[]""") - self.assertEqual(curs.fetchone()[0], [{'a': 100.0, 'b': None}]) + curs.execute("""select array['{"a": 100.0, "b": null}']::json[]""") + self.assertEqual(curs.fetchone()[0], [{'a': 100.0, 'b': None}]) @skip_if_no_json_type def test_register_on_connection(self): psycopg2.extras.register_json(self.conn) - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) @skip_if_no_json_type def test_register_on_cursor(self): - curs = self.conn.cursor() - psycopg2.extras.register_json(curs) - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) + with self.conn.cursor() as curs: + psycopg2.extras.register_json(curs) + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) @skip_if_no_json_type @restore_types def test_register_globally(self): new, newa = psycopg2.extras.register_json(self.conn, globally=True) - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) @skip_if_no_json_type def test_loads(self): @@ -883,9 +877,9 @@ class JsonTestCase(ConnectingTestCase): def loads(s): return json.loads(s, parse_float=Decimal) psycopg2.extras.register_json(self.conn, loads=loads) - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - data = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + data = curs.fetchone()[0] self.assert_(isinstance(data['a'], Decimal)) self.assertEqual(data['a'], Decimal('100.0')) @@ -899,49 +893,48 @@ class JsonTestCase(ConnectingTestCase): new, newa = psycopg2.extras.register_json( loads=loads, oid=oid, array_oid=array_oid) - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - data = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + data = curs.fetchone()[0] self.assert_(isinstance(data['a'], Decimal)) self.assertEqual(data['a'], Decimal('100.0')) @skip_before_postgres(9, 2) def test_register_default(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + def loads(s): + return psycopg2.extras.json.loads(s, parse_float=Decimal) + psycopg2.extras.register_default_json(curs, loads=loads) - def loads(s): - return psycopg2.extras.json.loads(s, parse_float=Decimal) - psycopg2.extras.register_default_json(curs, loads=loads) + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + data = curs.fetchone()[0] + self.assert_(isinstance(data['a'], Decimal)) + self.assertEqual(data['a'], Decimal('100.0')) - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - data = curs.fetchone()[0] - self.assert_(isinstance(data['a'], Decimal)) - self.assertEqual(data['a'], Decimal('100.0')) - - curs.execute("""select array['{"a": 100.0, "b": null}']::json[]""") - data = curs.fetchone()[0] - self.assert_(isinstance(data[0]['a'], Decimal)) - self.assertEqual(data[0]['a'], Decimal('100.0')) + curs.execute("""select array['{"a": 100.0, "b": null}']::json[]""") + data = curs.fetchone()[0] + self.assert_(isinstance(data[0]['a'], Decimal)) + self.assertEqual(data[0]['a'], Decimal('100.0')) @skip_if_no_json_type def test_null(self): psycopg2.extras.register_json(self.conn) - curs = self.conn.cursor() - curs.execute("""select NULL::json""") - self.assertEqual(curs.fetchone()[0], None) - curs.execute("""select NULL::json[]""") - self.assertEqual(curs.fetchone()[0], None) + with self.conn.cursor() as curs: + curs.execute("""select NULL::json""") + self.assertEqual(curs.fetchone()[0], None) + curs.execute("""select NULL::json[]""") + self.assertEqual(curs.fetchone()[0], None) def test_no_array_oid(self): - curs = self.conn.cursor() - t1, t2 = psycopg2.extras.register_json(curs, oid=25) - self.assertEqual(t1.values[0], 25) - self.assertEqual(t2, None) + with self.conn.cursor() as curs: + t1, t2 = psycopg2.extras.register_json(curs, oid=25) + self.assertEqual(t1.values[0], 25) + self.assertEqual(t2, None) - curs.execute("""select '{"a": 100.0, "b": null}'::text""") - data = curs.fetchone()[0] - self.assertEqual(data['a'], 100) - self.assertEqual(data['b'], None) + curs.execute("""select '{"a": 100.0, "b": null}'::text""") + data = curs.fetchone()[0] + self.assertEqual(data['a'], 100) + self.assertEqual(data['b'], None) def test_str(self): snowman = u"\u2603" @@ -956,20 +949,20 @@ class JsonTestCase(ConnectingTestCase): @skip_before_postgres(8, 2) def test_scs(self): cnn_on = self.connect(options="-c standard_conforming_strings=on") - cur_on = cnn_on.cursor() - self.assertEqual( - cur_on.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), - b'\'{"a": "\\""}\'') + with cnn_on.cursor() as cur_on: + self.assertEqual( + cur_on.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), + b'\'{"a": "\\""}\'') cnn_off = self.connect(options="-c standard_conforming_strings=off") - cur_off = cnn_off.cursor() - self.assertEqual( - cur_off.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), - b'E\'{"a": "\\\\""}\'') + with cnn_off.cursor() as cur_off: + self.assertEqual( + cur_off.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), + b'E\'{"a": "\\\\""}\'') - self.assertEqual( - cur_on.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), - b'\'{"a": "\\""}\'') + self.assertEqual( + cur_on.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), + b'\'{"a": "\\""}\'') def skip_if_no_jsonb_type(f): @@ -985,33 +978,32 @@ class JsonbTestCase(ConnectingTestCase): return rv def test_default_cast(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) - - curs.execute("""select array['{"a": 100.0, "b": null}']::jsonb[]""") - self.assertEqual(curs.fetchone()[0], [{'a': 100.0, 'b': None}]) + curs.execute("""select array['{"a": 100.0, "b": null}']::jsonb[]""") + self.assertEqual(curs.fetchone()[0], [{'a': 100.0, 'b': None}]) def test_register_on_connection(self): psycopg2.extras.register_json(self.conn, loads=self.myloads, name='jsonb') - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) def test_register_on_cursor(self): - curs = self.conn.cursor() - psycopg2.extras.register_json(curs, loads=self.myloads, name='jsonb') - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) + with self.conn.cursor() as curs: + psycopg2.extras.register_json(curs, loads=self.myloads, name='jsonb') + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) @restore_types def test_register_globally(self): new, newa = psycopg2.extras.register_json(self.conn, loads=self.myloads, globally=True, name='jsonb') - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) def test_loads(self): json = psycopg2.extras.json @@ -1020,41 +1012,40 @@ class JsonbTestCase(ConnectingTestCase): return json.loads(s, parse_float=Decimal) psycopg2.extras.register_json(self.conn, loads=loads, name='jsonb') - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - data = curs.fetchone()[0] - self.assert_(isinstance(data['a'], Decimal)) - self.assertEqual(data['a'], Decimal('100.0')) - # sure we are not manling json too? - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - data = curs.fetchone()[0] - self.assert_(isinstance(data['a'], float)) - self.assertEqual(data['a'], 100.0) + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + data = curs.fetchone()[0] + self.assert_(isinstance(data['a'], Decimal)) + self.assertEqual(data['a'], Decimal('100.0')) + # sure we are not manling json too? + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + data = curs.fetchone()[0] + self.assert_(isinstance(data['a'], float)) + self.assertEqual(data['a'], 100.0) def test_register_default(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + def loads(s): + return psycopg2.extras.json.loads(s, parse_float=Decimal) - def loads(s): - return psycopg2.extras.json.loads(s, parse_float=Decimal) + psycopg2.extras.register_default_jsonb(curs, loads=loads) - psycopg2.extras.register_default_jsonb(curs, loads=loads) + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + data = curs.fetchone()[0] + self.assert_(isinstance(data['a'], Decimal)) + self.assertEqual(data['a'], Decimal('100.0')) - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - data = curs.fetchone()[0] - self.assert_(isinstance(data['a'], Decimal)) - self.assertEqual(data['a'], Decimal('100.0')) - - curs.execute("""select array['{"a": 100.0, "b": null}']::jsonb[]""") - data = curs.fetchone()[0] - self.assert_(isinstance(data[0]['a'], Decimal)) - self.assertEqual(data[0]['a'], Decimal('100.0')) + curs.execute("""select array['{"a": 100.0, "b": null}']::jsonb[]""") + data = curs.fetchone()[0] + self.assert_(isinstance(data[0]['a'], Decimal)) + self.assertEqual(data[0]['a'], Decimal('100.0')) def test_null(self): - curs = self.conn.cursor() - curs.execute("""select NULL::jsonb""") - self.assertEqual(curs.fetchone()[0], None) - curs.execute("""select NULL::jsonb[]""") - self.assertEqual(curs.fetchone()[0], None) + with self.conn.cursor() as curs: + curs.execute("""select NULL::jsonb""") + self.assertEqual(curs.fetchone()[0], None) + curs.execute("""select NULL::jsonb[]""") + self.assertEqual(curs.fetchone()[0], None) class RangeTestCase(unittest.TestCase): @@ -1332,59 +1323,59 @@ class RangeCasterTestCase(ConnectingTestCase): 'daterange', 'tsrange', 'tstzrange') def test_cast_null(self): - cur = self.conn.cursor() - for type in self.builtin_ranges: - cur.execute("select NULL::%s" % type) - r = cur.fetchone()[0] - self.assertEqual(r, None) + with self.conn.cursor() as cur: + for type in self.builtin_ranges: + cur.execute("select NULL::%s" % type) + r = cur.fetchone()[0] + self.assertEqual(r, None) def test_cast_empty(self): - cur = self.conn.cursor() - for type in self.builtin_ranges: - cur.execute("select 'empty'::%s" % type) - r = cur.fetchone()[0] - self.assert_(isinstance(r, Range), type) - self.assert_(r.isempty) + with self.conn.cursor() as cur: + for type in self.builtin_ranges: + cur.execute("select 'empty'::%s" % type) + r = cur.fetchone()[0] + self.assert_(isinstance(r, Range), type) + self.assert_(r.isempty) def test_cast_inf(self): - cur = self.conn.cursor() - for type in self.builtin_ranges: - cur.execute("select '(,)'::%s" % type) - r = cur.fetchone()[0] - self.assert_(isinstance(r, Range), type) - self.assert_(not r.isempty) - self.assert_(r.lower_inf) - self.assert_(r.upper_inf) + with self.conn.cursor() as cur: + for type in self.builtin_ranges: + cur.execute("select '(,)'::%s" % type) + r = cur.fetchone()[0] + self.assert_(isinstance(r, Range), type) + self.assert_(not r.isempty) + self.assert_(r.lower_inf) + self.assert_(r.upper_inf) def test_cast_numbers(self): - cur = self.conn.cursor() - for type in ('int4range', 'int8range'): - cur.execute("select '(10,20)'::%s" % type) + with self.conn.cursor() as cur: + for type in ('int4range', 'int8range'): + cur.execute("select '(10,20)'::%s" % type) + r = cur.fetchone()[0] + self.assert_(isinstance(r, NumericRange)) + self.assert_(not r.isempty) + self.assertEqual(r.lower, 11) + self.assertEqual(r.upper, 20) + self.assert_(not r.lower_inf) + self.assert_(not r.upper_inf) + self.assert_(r.lower_inc) + self.assert_(not r.upper_inc) + + cur.execute("select '(10.2,20.6)'::numrange") r = cur.fetchone()[0] self.assert_(isinstance(r, NumericRange)) self.assert_(not r.isempty) - self.assertEqual(r.lower, 11) - self.assertEqual(r.upper, 20) + self.assertEqual(r.lower, Decimal('10.2')) + self.assertEqual(r.upper, Decimal('20.6')) self.assert_(not r.lower_inf) self.assert_(not r.upper_inf) - self.assert_(r.lower_inc) + self.assert_(not r.lower_inc) self.assert_(not r.upper_inc) - cur.execute("select '(10.2,20.6)'::numrange") - r = cur.fetchone()[0] - self.assert_(isinstance(r, NumericRange)) - self.assert_(not r.isempty) - self.assertEqual(r.lower, Decimal('10.2')) - self.assertEqual(r.upper, Decimal('20.6')) - self.assert_(not r.lower_inf) - self.assert_(not r.upper_inf) - self.assert_(not r.lower_inc) - self.assert_(not r.upper_inc) - def test_cast_date(self): - cur = self.conn.cursor() - cur.execute("select '(2000-01-01,2012-12-31)'::daterange") - r = cur.fetchone()[0] + with self.conn.cursor() as cur: + cur.execute("select '(2000-01-01,2012-12-31)'::daterange") + r = cur.fetchone()[0] self.assert_(isinstance(r, DateRange)) self.assert_(not r.isempty) self.assertEqual(r.lower, date(2000, 1, 2)) @@ -1395,11 +1386,11 @@ class RangeCasterTestCase(ConnectingTestCase): self.assert_(not r.upper_inc) def test_cast_timestamp(self): - cur = self.conn.cursor() - ts1 = datetime(2000, 1, 1) - ts2 = datetime(2000, 12, 31, 23, 59, 59, 999) - cur.execute("select tsrange(%s, %s, '()')", (ts1, ts2)) - r = cur.fetchone()[0] + with self.conn.cursor() as cur: + ts1 = datetime(2000, 1, 1) + ts2 = datetime(2000, 12, 31, 23, 59, 59, 999) + cur.execute("select tsrange(%s, %s, '()')", (ts1, ts2)) + r = cur.fetchone()[0] self.assert_(isinstance(r, DateTimeRange)) self.assert_(not r.isempty) self.assertEqual(r.lower, ts1) @@ -1410,12 +1401,12 @@ class RangeCasterTestCase(ConnectingTestCase): self.assert_(not r.upper_inc) def test_cast_timestamptz(self): - cur = self.conn.cursor() - ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600)) - ts2 = datetime(2000, 12, 31, 23, 59, 59, 999, - tzinfo=FixedOffsetTimezone(600)) - cur.execute("select tstzrange(%s, %s, '[]')", (ts1, ts2)) - r = cur.fetchone()[0] + with self.conn.cursor() as cur: + ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600)) + ts2 = datetime(2000, 12, 31, 23, 59, 59, 999, + tzinfo=FixedOffsetTimezone(600)) + cur.execute("select tstzrange(%s, %s, '[]')", (ts1, ts2)) + r = cur.fetchone()[0] self.assert_(isinstance(r, DateTimeTZRange)) self.assert_(not r.isempty) self.assertEqual(r.lower, ts1) @@ -1426,202 +1417,199 @@ class RangeCasterTestCase(ConnectingTestCase): self.assert_(r.upper_inc) def test_adapt_number_range(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + r = NumericRange(empty=True) + cur.execute("select %s::int4range", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, NumericRange)) + self.assert_(r1.isempty) - r = NumericRange(empty=True) - cur.execute("select %s::int4range", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, NumericRange)) - self.assert_(r1.isempty) + r = NumericRange(10, 20) + cur.execute("select %s::int8range", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, NumericRange)) + self.assertEqual(r1.lower, 10) + self.assertEqual(r1.upper, 20) + self.assert_(r1.lower_inc) + self.assert_(not r1.upper_inc) - r = NumericRange(10, 20) - cur.execute("select %s::int8range", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, NumericRange)) - self.assertEqual(r1.lower, 10) - self.assertEqual(r1.upper, 20) - self.assert_(r1.lower_inc) - self.assert_(not r1.upper_inc) - - r = NumericRange(Decimal('10.2'), Decimal('20.5'), '(]') - cur.execute("select %s::numrange", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, NumericRange)) - self.assertEqual(r1.lower, Decimal('10.2')) - self.assertEqual(r1.upper, Decimal('20.5')) - self.assert_(not r1.lower_inc) - self.assert_(r1.upper_inc) + r = NumericRange(Decimal('10.2'), Decimal('20.5'), '(]') + cur.execute("select %s::numrange", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, NumericRange)) + self.assertEqual(r1.lower, Decimal('10.2')) + self.assertEqual(r1.upper, Decimal('20.5')) + self.assert_(not r1.lower_inc) + self.assert_(r1.upper_inc) def test_adapt_numeric_range(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + r = NumericRange(empty=True) + cur.execute("select %s::int4range", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, NumericRange), r1) + self.assert_(r1.isempty) - r = NumericRange(empty=True) - cur.execute("select %s::int4range", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, NumericRange), r1) - self.assert_(r1.isempty) + r = NumericRange(10, 20) + cur.execute("select %s::int8range", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, NumericRange)) + self.assertEqual(r1.lower, 10) + self.assertEqual(r1.upper, 20) + self.assert_(r1.lower_inc) + self.assert_(not r1.upper_inc) - r = NumericRange(10, 20) - cur.execute("select %s::int8range", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, NumericRange)) - self.assertEqual(r1.lower, 10) - self.assertEqual(r1.upper, 20) - self.assert_(r1.lower_inc) - self.assert_(not r1.upper_inc) - - r = NumericRange(Decimal('10.2'), Decimal('20.5'), '(]') - cur.execute("select %s::numrange", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, NumericRange)) - self.assertEqual(r1.lower, Decimal('10.2')) - self.assertEqual(r1.upper, Decimal('20.5')) - self.assert_(not r1.lower_inc) - self.assert_(r1.upper_inc) + r = NumericRange(Decimal('10.2'), Decimal('20.5'), '(]') + cur.execute("select %s::numrange", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, NumericRange)) + self.assertEqual(r1.lower, Decimal('10.2')) + self.assertEqual(r1.upper, Decimal('20.5')) + self.assert_(not r1.lower_inc) + self.assert_(r1.upper_inc) def test_adapt_date_range(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + d1 = date(2012, 1, 1) + d2 = date(2012, 12, 31) + r = DateRange(d1, d2) + cur.execute("select %s", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, DateRange)) + self.assertEqual(r1.lower, d1) + self.assertEqual(r1.upper, d2) + self.assert_(r1.lower_inc) + self.assert_(not r1.upper_inc) - d1 = date(2012, 1, 1) - d2 = date(2012, 12, 31) - r = DateRange(d1, d2) - cur.execute("select %s", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, DateRange)) - self.assertEqual(r1.lower, d1) - self.assertEqual(r1.upper, d2) - self.assert_(r1.lower_inc) - self.assert_(not r1.upper_inc) + r = DateTimeRange(empty=True) + cur.execute("select %s", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, DateTimeRange)) + self.assert_(r1.isempty) - r = DateTimeRange(empty=True) - cur.execute("select %s", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, DateTimeRange)) - self.assert_(r1.isempty) - - ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600)) - ts2 = datetime(2000, 12, 31, 23, 59, 59, 999, - tzinfo=FixedOffsetTimezone(600)) - r = DateTimeTZRange(ts1, ts2, '(]') - cur.execute("select %s", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, DateTimeTZRange)) - self.assertEqual(r1.lower, ts1) - self.assertEqual(r1.upper, ts2) - self.assert_(not r1.lower_inc) - self.assert_(r1.upper_inc) + ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600)) + ts2 = datetime(2000, 12, 31, 23, 59, 59, 999, + tzinfo=FixedOffsetTimezone(600)) + r = DateTimeTZRange(ts1, ts2, '(]') + cur.execute("select %s", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, DateTimeTZRange)) + self.assertEqual(r1.lower, ts1) + self.assertEqual(r1.upper, ts2) + self.assert_(not r1.lower_inc) + self.assert_(r1.upper_inc) @restore_types def test_register_range_adapter(self): - cur = self.conn.cursor() - cur.execute("create type textrange as range (subtype=text)") - rc = register_range('textrange', 'TextRange', cur) + with self.conn.cursor() as cur: + cur.execute("create type textrange as range (subtype=text)") + rc = register_range('textrange', 'TextRange', cur) - TextRange = rc.range - self.assert_(issubclass(TextRange, Range)) - self.assertEqual(TextRange.__name__, 'TextRange') + TextRange = rc.range + self.assert_(issubclass(TextRange, Range)) + self.assertEqual(TextRange.__name__, 'TextRange') - r = TextRange('a', 'b', '(]') - cur.execute("select %s", (r,)) - r1 = cur.fetchone()[0] - self.assertEqual(r1.lower, 'a') - self.assertEqual(r1.upper, 'b') - self.assert_(not r1.lower_inc) - self.assert_(r1.upper_inc) - - cur.execute("select %s", ([r, r, r],)) - rs = cur.fetchone()[0] - self.assertEqual(len(rs), 3) - for r1 in rs: + r = TextRange('a', 'b', '(]') + cur.execute("select %s", (r,)) + r1 = cur.fetchone()[0] self.assertEqual(r1.lower, 'a') self.assertEqual(r1.upper, 'b') self.assert_(not r1.lower_inc) self.assert_(r1.upper_inc) + cur.execute("select %s", ([r, r, r],)) + rs = cur.fetchone()[0] + self.assertEqual(len(rs), 3) + for r1 in rs: + self.assertEqual(r1.lower, 'a') + self.assertEqual(r1.upper, 'b') + self.assert_(not r1.lower_inc) + self.assert_(r1.upper_inc) + def test_range_escaping(self): - cur = self.conn.cursor() - cur.execute("create type textrange as range (subtype=text)") - rc = register_range('textrange', 'TextRange', cur) + with self.conn.cursor() as cur: + cur.execute("create type textrange as range (subtype=text)") + rc = register_range('textrange', 'TextRange', cur) - TextRange = rc.range - cur.execute(""" - create table rangetest ( - id integer primary key, - range textrange)""") + TextRange = rc.range + cur.execute(""" + create table rangetest ( + id integer primary key, + range textrange)""") - bounds = ['[)', '(]', '()', '[]'] - ranges = [TextRange(low, up, bounds[i % 4]) - for i, (low, up) in enumerate(zip( - [None] + list(map(chr, range(1, 128))), - list(map(chr, range(1, 128))) + [None], - ))] - ranges.append(TextRange()) - ranges.append(TextRange(empty=True)) + bounds = ['[)', '(]', '()', '[]'] + ranges = [TextRange(low, up, bounds[i % 4]) + for i, (low, up) in enumerate(zip( + [None] + list(map(chr, range(1, 128))), + list(map(chr, range(1, 128))) + [None], + ))] + ranges.append(TextRange()) + ranges.append(TextRange(empty=True)) - errs = 0 - for i, r in enumerate(ranges): - # not all the ranges make sense: - # fun fact: select ascii('#') < ascii('$'), '#' < '$' - # yelds... t, f! At least in en_GB.UTF-8 collation. - # which seems suggesting a supremacy of the pound on the dollar. - # So some of these ranges will fail to insert. Be prepared but... - try: - cur.execute(""" - savepoint x; - insert into rangetest (id, range) values (%s, %s); - """, (i, r)) - except psycopg2.DataError: - errs += 1 - cur.execute("rollback to savepoint x;") + errs = 0 + for i, r in enumerate(ranges): + # not all the ranges make sense: + # fun fact: select ascii('#') < ascii('$'), '#' < '$' + # yelds... t, f! At least in en_GB.UTF-8 collation. + # which seems suggesting a supremacy of the pound on the dollar. + # So some of these ranges will fail to insert. Be prepared but... + try: + cur.execute(""" + savepoint x; + insert into rangetest (id, range) values (%s, %s); + """, (i, r)) + except psycopg2.DataError: + errs += 1 + cur.execute("rollback to savepoint x;") - # ...not too many errors! in the above collate there are 17 errors: - # assume in other collates we won't find more than 30 - self.assert_(errs < 30, - "too many collate errors. Is the test working?") + # ...not too many errors! in the above collate there are 17 errors: + # assume in other collates we won't find more than 30 + self.assert_(errs < 30, + "too many collate errors. Is the test working?") - cur.execute("select id, range from rangetest order by id") - for i, r in cur: - self.assertEqual(ranges[i].lower, r.lower) - self.assertEqual(ranges[i].upper, r.upper) - self.assertEqual(ranges[i].lower_inc, r.lower_inc) - self.assertEqual(ranges[i].upper_inc, r.upper_inc) - self.assertEqual(ranges[i].lower_inf, r.lower_inf) - self.assertEqual(ranges[i].upper_inf, r.upper_inf) + cur.execute("select id, range from rangetest order by id") + for i, r in cur: + self.assertEqual(ranges[i].lower, r.lower) + self.assertEqual(ranges[i].upper, r.upper) + self.assertEqual(ranges[i].lower_inc, r.lower_inc) + self.assertEqual(ranges[i].upper_inc, r.upper_inc) + self.assertEqual(ranges[i].lower_inf, r.lower_inf) + self.assertEqual(ranges[i].upper_inf, r.upper_inf) # clear the adapters to allow precise count by scripts/refcounter.py del ext.adapters[TextRange, ext.ISQLQuote] def test_range_not_found(self): - cur = self.conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, - register_range, 'nosuchrange', 'FailRange', cur) + with self.conn.cursor() as cur: + self.assertRaises(psycopg2.ProgrammingError, + register_range, 'nosuchrange', 'FailRange', cur) @restore_types def test_schema_range(self): - cur = self.conn.cursor() - cur.execute("create schema rs") - cur.execute("create type r1 as range (subtype=text)") - cur.execute("create type r2 as range (subtype=text)") - cur.execute("create type rs.r2 as range (subtype=text)") - cur.execute("create type rs.r3 as range (subtype=text)") - cur.execute("savepoint x") + with self.conn.cursor() as cur: + cur.execute("create schema rs") + cur.execute("create type r1 as range (subtype=text)") + cur.execute("create type r2 as range (subtype=text)") + cur.execute("create type rs.r2 as range (subtype=text)") + cur.execute("create type rs.r3 as range (subtype=text)") + cur.execute("savepoint x") - register_range('r1', 'r1', cur) - ra2 = register_range('r2', 'r2', cur) - rars2 = register_range('rs.r2', 'r2', cur) - register_range('rs.r3', 'r3', cur) + register_range('r1', 'r1', cur) + ra2 = register_range('r2', 'r2', cur) + rars2 = register_range('rs.r2', 'r2', cur) + register_range('rs.r3', 'r3', cur) - self.assertNotEqual( - ra2.typecaster.values[0], - rars2.typecaster.values[0]) + self.assertNotEqual( + ra2.typecaster.values[0], + rars2.typecaster.values[0]) - self.assertRaises(psycopg2.ProgrammingError, - register_range, 'r3', 'FailRange', cur) - cur.execute("rollback to savepoint x;") + self.assertRaises(psycopg2.ProgrammingError, + register_range, 'r3', 'FailRange', cur) + cur.execute("rollback to savepoint x;") - self.assertRaises(psycopg2.ProgrammingError, - register_range, 'rs.r1', 'FailRange', cur) - cur.execute("rollback to savepoint x;") + self.assertRaises(psycopg2.ProgrammingError, + register_range, 'rs.r1', 'FailRange', cur) + cur.execute("rollback to savepoint x;") class TestSolveConnCurs(ConnectingTestCase): diff --git a/tests/test_with.py b/tests/test_with.py index b8c043f6..13173452 100755 --- a/tests/test_with.py +++ b/tests/test_with.py @@ -33,15 +33,15 @@ from .testutils import ConnectingTestCase, skip_before_postgres class WithTestCase(ConnectingTestCase): def setUp(self): ConnectingTestCase.setUp(self) - curs = self.conn.cursor() - try: - curs.execute("delete from test_with") - self.conn.commit() - except psycopg2.ProgrammingError: - # assume table doesn't exist - self.conn.rollback() - curs.execute("create table test_with (id integer primary key)") - self.conn.commit() + with self.conn.cursor() as curs: + try: + curs.execute("delete from test_with") + self.conn.commit() + except psycopg2.ProgrammingError: + # assume table doesn't exist + self.conn.rollback() + curs.execute("create table test_with (id integer primary key)") + self.conn.commit() class WithConnectionTestCase(WithTestCase): @@ -49,59 +49,59 @@ class WithConnectionTestCase(WithTestCase): with self.conn as conn: self.assert_(self.conn is conn) self.assertEqual(conn.status, ext.STATUS_READY) - curs = conn.cursor() - curs.execute("insert into test_with values (1)") + with conn.cursor() as curs: + curs.execute("insert into test_with values (1)") self.assertEqual(conn.status, ext.STATUS_BEGIN) self.assertEqual(self.conn.status, ext.STATUS_READY) self.assert_(not self.conn.closed) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), [(1,)]) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), [(1,)]) def test_with_connect_idiom(self): with self.connect() as conn: self.assertEqual(conn.status, ext.STATUS_READY) - curs = conn.cursor() - curs.execute("insert into test_with values (2)") - self.assertEqual(conn.status, ext.STATUS_BEGIN) + with conn.cursor() as curs: + curs.execute("insert into test_with values (2)") + self.assertEqual(conn.status, ext.STATUS_BEGIN) self.assertEqual(self.conn.status, ext.STATUS_READY) self.assert_(not self.conn.closed) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), [(2,)]) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), [(2,)]) def test_with_error_db(self): def f(): with self.conn as conn: - curs = conn.cursor() - curs.execute("insert into test_with values ('a')") + with conn.cursor() as curs: + curs.execute("insert into test_with values ('a')") self.assertRaises(psycopg2.DataError, f) self.assertEqual(self.conn.status, ext.STATUS_READY) self.assert_(not self.conn.closed) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), []) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), []) def test_with_error_python(self): def f(): with self.conn as conn: - curs = conn.cursor() - curs.execute("insert into test_with values (3)") - 1 / 0 + with conn.cursor() as curs: + curs.execute("insert into test_with values (3)") + 1 / 0 self.assertRaises(ZeroDivisionError, f) self.assertEqual(self.conn.status, ext.STATUS_READY) self.assert_(not self.conn.closed) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), []) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), []) def test_with_closed(self): def f(): @@ -120,15 +120,15 @@ class WithConnectionTestCase(WithTestCase): super(MyConn, self).commit() with self.connect(connection_factory=MyConn) as conn: - curs = conn.cursor() - curs.execute("insert into test_with values (10)") + with conn.cursor() as curs: + curs.execute("insert into test_with values (10)") self.assertEqual(conn.status, ext.STATUS_READY) self.assert_(commits) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), [(10,)]) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), [(10,)]) def test_subclass_rollback(self): rollbacks = [] @@ -140,9 +140,9 @@ class WithConnectionTestCase(WithTestCase): try: with self.connect(connection_factory=MyConn) as conn: - curs = conn.cursor() - curs.execute("insert into test_with values (11)") - 1 / 0 + with conn.cursor() as curs: + curs.execute("insert into test_with values (11)") + 1 / 0 except ZeroDivisionError: pass else: @@ -151,9 +151,9 @@ class WithConnectionTestCase(WithTestCase): self.assertEqual(conn.status, ext.STATUS_READY) self.assert_(rollbacks) - curs = conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), []) + with conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), []) class WithCursorTestCase(WithTestCase): @@ -168,9 +168,9 @@ class WithCursorTestCase(WithTestCase): self.assertEqual(self.conn.status, ext.STATUS_READY) self.assert_(not self.conn.closed) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), [(4,)]) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), [(4,)]) def test_with_error(self): try: @@ -185,9 +185,9 @@ class WithCursorTestCase(WithTestCase): self.assert_(not self.conn.closed) self.assert_(curs.closed) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), []) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), []) def test_subclass(self): closes = [] diff --git a/tests/testutils.py b/tests/testutils.py index cd69c468..b0770b34 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -228,9 +228,9 @@ def skip_if_no_uuid(f): @wraps(f) def skip_if_no_uuid_(self): try: - cur = self.conn.cursor() - cur.execute("select typname from pg_type where typname = 'uuid'") - has = cur.fetchone() + with self.conn.cursor() as cur: + cur.execute("select typname from pg_type where typname = 'uuid'") + has = cur.fetchone() finally: self.conn.rollback() @@ -249,14 +249,14 @@ def skip_if_tpc_disabled(f): def skip_if_tpc_disabled_(self): cnn = self.connect() try: - cur = cnn.cursor() - try: - cur.execute("SHOW max_prepared_transactions;") - except psycopg2.ProgrammingError: - return self.skipTest( - "server too old: two phase transactions not supported.") - else: - mtp = int(cur.fetchone()[0]) + with cnn.cursor() as cur: + try: + cur.execute("SHOW max_prepared_transactions;") + except psycopg2.ProgrammingError: + return self.skipTest( + "server too old: two phase transactions not supported.") + else: + mtp = int(cur.fetchone()[0]) finally: cnn.close() From 280eefd2f406161b015f681fbacc631973b94e1a Mon Sep 17 00:00:00 2001 From: Jon Dufresne Date: Sun, 2 Feb 2020 16:20:52 -0800 Subject: [PATCH 4/8] Emit a warning when a connection or cursor isn't closed This follows the semantics of Python file objects. When the object is garbage collected and it has not been closed, emit a ResourceWarning. This allows library users to notice inconsistent and non-deterministic resource management and fix it. Users should use context managers and finally blocks to always close cursors and connections when they are not longer required. For example, in Python, not closing a file results in the following: $ python3 -Walways >>> f = open('foo', 'w') >>> del f __main__:1: ResourceWarning: unclosed file <_io.TextIOWrapper name='foo' mode='w' encoding='UTF-8'> ResourceWarning: Enable tracemalloc to get the object allocation traceback psycopg now acts the same way: $ python3 -Walways >>> import psycopg2 >>> c = psycopg2.connect(database='psycopg2_test') >>> del c :1: ResourceWarning: unclosed connection ResourceWarning: Enable tracemalloc to get the object allocation traceback All warnings noticed during testing has been fixed. --- psycopg/connection_type.c | 73 ++++++++++++++++++++++++++++++++++----- psycopg/cursor_type.c | 42 ++++++++++++++++++++-- psycopg/python.h | 4 +++ tests/__init__.py | 2 ++ tests/test_warnings.py | 63 +++++++++++++++++++++++++++++++++ 5 files changed, 174 insertions(+), 10 deletions(-) create mode 100644 tests/test_warnings.py diff --git a/psycopg/connection_type.c b/psycopg/connection_type.c index 25299fab..5caf644d 100644 --- a/psycopg/connection_type.c +++ b/psycopg/connection_type.c @@ -59,7 +59,7 @@ static PyObject * psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs) { PyObject *obj = NULL; - PyObject *rv = NULL; + cursorObject *curs = NULL; PyObject *name = Py_None; PyObject *factory = Py_None; PyObject *withhold = Py_False; @@ -110,14 +110,17 @@ psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs) if (PyObject_IsInstance(obj, (PyObject *)&cursorType) == 0) { PyErr_SetString(PyExc_TypeError, "cursor factory must be subclass of psycopg2.extensions.cursor"); + Py_DECREF(obj); + obj = NULL; goto exit; } - if (0 > curs_withhold_set((cursorObject *)obj, withhold)) { - goto exit; + curs = (cursorObject *)obj; + if (0 > curs_withhold_set(curs, withhold)) { + goto error; } - if (0 > curs_scrollable_set((cursorObject *)obj, scrollable)) { - goto exit; + if (0 > curs_scrollable_set(curs, scrollable)) { + goto error; } Dprintf("psyco_conn_cursor: new cursor at %p: refcnt = " @@ -125,12 +128,26 @@ psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs) obj, Py_REFCNT(obj) ); - rv = obj; obj = NULL; + goto exit; + +error: + { + PyObject *error_type, *error_value, *error_traceback; + PyObject *close; + curs = NULL; + PyErr_Fetch(&error_type, &error_value, &error_traceback); + close = PyObject_CallMethod(obj, "close", NULL); + if (close) + Py_DECREF(close); + else + PyErr_WriteUnraisable(obj); + PyErr_Restore(error_type, error_value, error_traceback); + } exit: Py_XDECREF(obj); - return rv; + return (PyObject *)curs; } @@ -1366,6 +1383,9 @@ connection_dealloc(PyObject* obj) { connectionObject *self = (connectionObject *)obj; + if (PyObject_CallFinalizerFromDealloc(obj) < 0) + return; + /* Make sure to untrack the connection before calling conn_close, which may * allow a different thread to try and dealloc the connection again, * resulting in a double-free segfault (ticket #166). */ @@ -1405,6 +1425,31 @@ connection_dealloc(PyObject* obj) Py_TYPE(obj)->tp_free(obj); } +#if PY_3 +static void +connection_finalize(PyObject *obj) +{ + connectionObject *self = (connectionObject *)obj; + +#ifdef CONN_CHECK_PID + if (self->procpid == getpid()) +#endif + { + if (!self->closed) { + PyObject *error_type, *error_value, *error_traceback; + /* Save the current exception, if any. */ + PyErr_Fetch(&error_type, &error_value, &error_traceback); + + if (PyErr_WarnFormat(PyExc_ResourceWarning, 1, "unclosed connection %R", obj)) + PyErr_WriteUnraisable(obj); + + /* Restore the saved exception. */ + PyErr_Restore(error_type, error_value, error_traceback); + } + } +} +#endif + static int connection_init(PyObject *obj, PyObject *args, PyObject *kwds) { @@ -1479,7 +1524,7 @@ PyTypeObject connectionType = { 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC | - Py_TPFLAGS_HAVE_WEAKREFS, + Py_TPFLAGS_HAVE_WEAKREFS | Py_TPFLAGS_HAVE_FINALIZE, /*tp_flags*/ connectionType_doc, /*tp_doc*/ (traverseproc)connection_traverse, /*tp_traverse*/ @@ -1499,4 +1544,16 @@ PyTypeObject connectionType = { connection_init, /*tp_init*/ 0, /*tp_alloc*/ connection_new, /*tp_new*/ + 0, /* tp_free */ + 0, /* tp_is_gc */ + 0, /* tp_bases */ + 0, /* tp_mro */ + 0, /* tp_cache */ + 0, /* tp_subclasses */ + 0, /* tp_weaklist */ +#if PY_3 + 0, /* tp_del */ + 0, /* tp_version_tag */ + connection_finalize, /* tp_finalize */ +#endif }; diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index a7bd11b4..4ccbc135 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -39,6 +39,7 @@ #include +#define cursor_closed(self) ((self)->closed || ((self)->conn && (self)->conn->closed)) /** DBAPI methods **/ @@ -1637,7 +1638,7 @@ exit: static PyObject * curs_closed_get(cursorObject *self, void *closure) { - return PyBool_FromLong(self->closed || (self->conn && self->conn->closed)); + return PyBool_FromLong(cursor_closed(self)); } /* extension: withhold - get or set "WITH HOLD" for named cursors */ @@ -1945,6 +1946,9 @@ cursor_dealloc(PyObject* obj) { cursorObject *self = (cursorObject *)obj; + if (PyObject_CallFinalizerFromDealloc(obj) < 0) + return; + PyObject_GC_UnTrack(self); if (self->weakreflist) { @@ -1965,6 +1969,28 @@ cursor_dealloc(PyObject* obj) Py_TYPE(obj)->tp_free(obj); } +#if PY_3 +static void +cursor_finalize(PyObject *obj) +{ + cursorObject *self = (cursorObject *)obj; + + if (!cursor_closed(self)) { + PyObject *error_type, *error_value, *error_traceback; + /* Save the current exception, if any. */ + PyErr_Fetch(&error_type, &error_value, &error_traceback); + + if (PyErr_WarnFormat(PyExc_ResourceWarning, 1, + "unclosed cursor %R for connection %R", + obj, (PyObject *)self->conn)) + PyErr_WriteUnraisable(obj); + + /* Restore the saved exception. */ + PyErr_Restore(error_type, error_value, error_traceback); + } +} +#endif + static int cursor_init(PyObject *obj, PyObject *args, PyObject *kwargs) { @@ -2056,7 +2082,7 @@ PyTypeObject cursorType = { 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_ITER | - Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_HAVE_WEAKREFS , + Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_HAVE_WEAKREFS | Py_TPFLAGS_HAVE_FINALIZE, /*tp_flags*/ cursorType_doc, /*tp_doc*/ (traverseproc)cursor_traverse, /*tp_traverse*/ @@ -2076,4 +2102,16 @@ PyTypeObject cursorType = { cursor_init, /*tp_init*/ 0, /*tp_alloc*/ cursor_new, /*tp_new*/ + 0, /* tp_free */ + 0, /* tp_is_gc */ + 0, /* tp_bases */ + 0, /* tp_mro */ + 0, /* tp_cache */ + 0, /* tp_subclasses */ + 0, /* tp_weaklist */ +#if PY_3 + 0, /* tp_del */ + 0, /* tp_version_tag */ + cursor_finalize, /* tp_finalize */ +#endif }; diff --git a/psycopg/python.h b/psycopg/python.h index 2a5f9d83..a38231fc 100644 --- a/psycopg/python.h +++ b/psycopg/python.h @@ -86,6 +86,8 @@ typedef unsigned long Py_uhash_t; #define Bytes_ConcatAndDel PyString_ConcatAndDel #define _Bytes_Resize _PyString_Resize +#define Py_TPFLAGS_HAVE_FINALIZE 0L + #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days) #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds) #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds) @@ -97,6 +99,8 @@ typedef unsigned long Py_uhash_t; PyLong_FromUnsignedLong((unsigned long)(x)) : \ PyInt_FromLong((x))) +#define PyObject_CallFinalizerFromDealloc(obj) 0 + #endif /* PY_2 */ #if PY_3 diff --git a/tests/__init__.py b/tests/__init__.py index f5c422f4..a785f8bf 100755 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -56,6 +56,7 @@ from . import test_sql from . import test_transaction from . import test_types_basic from . import test_types_extras +from . import test_warnings from . import test_with if sys.version_info[:2] < (3, 6): @@ -101,6 +102,7 @@ def test_suite(): suite.addTest(test_transaction.test_suite()) suite.addTest(test_types_basic.test_suite()) suite.addTest(test_types_extras.test_suite()) + suite.addTest(test_warnings.test_suite()) suite.addTest(test_with.test_suite()) return suite diff --git a/tests/test_warnings.py b/tests/test_warnings.py new file mode 100644 index 00000000..0644fb10 --- /dev/null +++ b/tests/test_warnings.py @@ -0,0 +1,63 @@ +import unittest +import warnings + +import psycopg2 + +from .testconfig import dsn +from .testutils import skip_before_python + + +class WarningsTest(unittest.TestCase): + @skip_before_python(3) + def test_connection_not_closed(self): + def f(): + psycopg2.connect(dsn) + + msg = ( + "^unclosed connection $" + ) + with self.assertWarnsRegex(ResourceWarning, msg): + f() + + @skip_before_python(3) + def test_cursor_not_closed(self): + def f(): + conn = psycopg2.connect(dsn) + try: + conn.cursor() + finally: + conn.close() + + msg = ( + "^unclosed cursor for " + "connection $" + ) + with self.assertWarnsRegex(ResourceWarning, msg): + f() + + def test_cursor_factory_returns_non_cursor(self): + def bad_factory(*args, **kwargs): + return object() + + def f(): + conn = psycopg2.connect(dsn) + try: + conn.cursor(cursor_factory=bad_factory) + finally: + conn.close() + + with warnings.catch_warnings(record=True) as cm: + with self.assertRaises(TypeError): + f() + + # No warning as no cursor was instantiated. + self.assertEquals(cm, []) + + +def test_suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + + +if __name__ == "__main__": + unittest.main() From 6bb6f42169bcaccda9ba98a394fa937451a38657 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 9 Feb 2020 15:26:13 +0000 Subject: [PATCH 5/8] Clearer object ownership in psyco_conn_curs Make sure the object is assigned only to one variable at time, across obj, curs, rv (we create it on obj, pass it to curs when we know the type is right, pass it to rv when we know there was no error). --- psycopg/connection_type.c | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/psycopg/connection_type.c b/psycopg/connection_type.c index 5caf644d..fa5c4c8e 100644 --- a/psycopg/connection_type.c +++ b/psycopg/connection_type.c @@ -60,6 +60,7 @@ psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs) { PyObject *obj = NULL; cursorObject *curs = NULL; + PyObject *rv = NULL; PyObject *name = Py_None; PyObject *factory = Py_None; PyObject *withhold = Py_False; @@ -110,12 +111,13 @@ psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs) if (PyObject_IsInstance(obj, (PyObject *)&cursorType) == 0) { PyErr_SetString(PyExc_TypeError, "cursor factory must be subclass of psycopg2.extensions.cursor"); - Py_DECREF(obj); - obj = NULL; goto exit; } + /* pass ownership from obj to curs */ curs = (cursorObject *)obj; + obj = NULL; + if (0 > curs_withhold_set(curs, withhold)) { goto error; } @@ -128,26 +130,31 @@ psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs) obj, Py_REFCNT(obj) ); - obj = NULL; + /* pass ownership from curs to rv */ + rv = (PyObject *)curs; + curs = NULL; + goto exit; error: { PyObject *error_type, *error_value, *error_traceback; PyObject *close; - curs = NULL; PyErr_Fetch(&error_type, &error_value, &error_traceback); - close = PyObject_CallMethod(obj, "close", NULL); - if (close) - Py_DECREF(close); - else - PyErr_WriteUnraisable(obj); + if (curs) { + close = PyObject_CallMethod((PyObject *)curs, "close", NULL); + if (close) + Py_DECREF(close); + else + PyErr_WriteUnraisable((PyObject *)curs); + } PyErr_Restore(error_type, error_value, error_traceback); } exit: Py_XDECREF(obj); - return (PyObject *)curs; + Py_XDECREF(curs); + return rv; } From 1e0077a2b41794fc8f3a3349b6baaea36e72e540 Mon Sep 17 00:00:00 2001 From: Jon Dufresne Date: Tue, 11 Feb 2020 17:17:15 -0800 Subject: [PATCH 6/8] Fix curs references --- psycopg/connection_type.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/psycopg/connection_type.c b/psycopg/connection_type.c index fa5c4c8e..5ba0349c 100644 --- a/psycopg/connection_type.c +++ b/psycopg/connection_type.c @@ -127,7 +127,7 @@ psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs) Dprintf("psyco_conn_cursor: new cursor at %p: refcnt = " FORMAT_CODE_PY_SSIZE_T, - obj, Py_REFCNT(obj) + curs, Py_REFCNT(curs) ); /* pass ownership from curs to rv */ From d5110e5191f59113fc122fee1ac535c8e07a8344 Mon Sep 17 00:00:00 2001 From: Jon Dufresne Date: Tue, 11 Feb 2020 17:32:57 -0800 Subject: [PATCH 7/8] Allow capital hex digits --- tests/test_warnings.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_warnings.py b/tests/test_warnings.py index 0644fb10..3e09c595 100644 --- a/tests/test_warnings.py +++ b/tests/test_warnings.py @@ -14,7 +14,7 @@ class WarningsTest(unittest.TestCase): psycopg2.connect(dsn) msg = ( - "^unclosed connection $" ) with self.assertWarnsRegex(ResourceWarning, msg): @@ -30,8 +30,8 @@ class WarningsTest(unittest.TestCase): conn.close() msg = ( - "^unclosed cursor for " - "connection $" + "^unclosed cursor for " + "connection $" ) with self.assertWarnsRegex(ResourceWarning, msg): f() From e5322a71bb904788845e2321e204038464c8434d Mon Sep 17 00:00:00 2001 From: Jon Dufresne Date: Tue, 11 Feb 2020 17:55:37 -0800 Subject: [PATCH 8/8] Add test for ignored exception when .close() fails --- tests/test_warnings.py | 53 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/test_warnings.py b/tests/test_warnings.py index 3e09c595..cf003de1 100644 --- a/tests/test_warnings.py +++ b/tests/test_warnings.py @@ -1,3 +1,6 @@ +import re +import subprocess +import sys import unittest import warnings @@ -54,6 +57,56 @@ class WarningsTest(unittest.TestCase): # No warning as no cursor was instantiated. self.assertEquals(cm, []) + @skip_before_python(3) + def test_broken_close(self): + script = """ +import psycopg2 + +class MyException(Exception): + pass + +class MyCurs(psycopg2.extensions.cursor): + def close(self): + raise MyException + +def f(): + conn = psycopg2.connect(%(dsn)r) + try: + conn.cursor(cursor_factory=MyCurs, scrollable=True) + finally: + conn.close() + +f() +""" % {"dsn": dsn} + p = subprocess.Popen( + [sys.executable, "-Walways", "-c", script], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + output, _ = p.communicate() + output = output.decode() + # Normalize line endings. + output = "\n".join(output.splitlines()) + self.assertRegex( + output, + re.compile( + r"^Exception ignored in: " + r"$", + re.M, + ), + ) + self.assertIn("\n__main__.MyException: \n", output) + self.assertRegex( + output, + re.compile( + r"ResourceWarning: unclosed cursor " + r" " + r"for connection " + r"$", + re.M, + ), + ) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)