diff --git a/tests/test_async.py b/tests/test_async.py index 8624be4a..7a237b5e 100755 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -61,16 +61,18 @@ class AsyncTests(ConnectingTestCase): self.wait(self.conn) - curs = self.conn.cursor() - curs.execute(''' - CREATE TEMPORARY TABLE table1 ( - id int PRIMARY KEY - )''') - self.wait(curs) + with self.conn.cursor() as curs: + curs.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') + self.wait(curs) def test_connection_setup(self): cur = self.conn.cursor() sync_cur = self.sync_conn.cursor() + cur.close() + sync_cur.close() del cur, sync_cur self.assert_(self.conn.async_) @@ -90,159 +92,156 @@ class AsyncTests(ConnectingTestCase): self.conn.cursor, "name") def test_async_select(self): - cur = self.conn.cursor() - self.assertFalse(self.conn.isexecuting()) - cur.execute("select 'a'") - self.assertTrue(self.conn.isexecuting()) + with self.conn.cursor() as cur: + self.assertFalse(self.conn.isexecuting()) + cur.execute("select 'a'") + self.assertTrue(self.conn.isexecuting()) - self.wait(cur) + self.wait(cur) - self.assertFalse(self.conn.isexecuting()) - self.assertEquals(cur.fetchone()[0], "a") + self.assertFalse(self.conn.isexecuting()) + self.assertEquals(cur.fetchone()[0], "a") @slow @skip_before_postgres(8, 2) def test_async_callproc(self): - cur = self.conn.cursor() - cur.callproc("pg_sleep", (0.1, )) - self.assertTrue(self.conn.isexecuting()) + with self.conn.cursor() as cur: + cur.callproc("pg_sleep", (0.1, )) + self.assertTrue(self.conn.isexecuting()) - self.wait(cur) - self.assertFalse(self.conn.isexecuting()) - self.assertEquals(cur.fetchall()[0][0], '') + self.wait(cur) + self.assertFalse(self.conn.isexecuting()) + self.assertEquals(cur.fetchall()[0][0], '') @slow def test_async_after_async(self): - cur = self.conn.cursor() - cur2 = self.conn.cursor() - del cur2 + with self.conn.cursor() as cur: + cur2 = self.conn.cursor() + cur2.close() + del cur2 - cur.execute("insert into table1 values (1)") + cur.execute("insert into table1 values (1)") - # an async execute after an async one raises an exception - self.assertRaises(psycopg2.ProgrammingError, - cur.execute, "select * from table1") - # same for callproc - self.assertRaises(psycopg2.ProgrammingError, - cur.callproc, "version") - # but after you've waited it should be good - self.wait(cur) - cur.execute("select * from table1") - self.wait(cur) + # an async execute after an async one raises an exception + self.assertRaises(psycopg2.ProgrammingError, + cur.execute, "select * from table1") + # same for callproc + self.assertRaises(psycopg2.ProgrammingError, + cur.callproc, "version") + # but after you've waited it should be good + self.wait(cur) + cur.execute("select * from table1") + self.wait(cur) - self.assertEquals(cur.fetchall()[0][0], 1) + self.assertEquals(cur.fetchall()[0][0], 1) - cur.execute("delete from table1") - self.wait(cur) + cur.execute("delete from table1") + self.wait(cur) - cur.execute("select * from table1") - self.wait(cur) + cur.execute("select * from table1") + self.wait(cur) - self.assertEquals(cur.fetchone(), None) + self.assertEquals(cur.fetchone(), None) def test_fetch_after_async(self): - cur = self.conn.cursor() - cur.execute("select 'a'") + with self.conn.cursor() as cur: + cur.execute("select 'a'") - # a fetch after an asynchronous query should raise an error - self.assertRaises(psycopg2.ProgrammingError, - cur.fetchall) - # but after waiting it should work - self.wait(cur) - self.assertEquals(cur.fetchall()[0][0], "a") + # a fetch after an asynchronous query should raise an error + self.assertRaises(psycopg2.ProgrammingError, + cur.fetchall) + # but after waiting it should work + self.wait(cur) + self.assertEquals(cur.fetchall()[0][0], "a") def test_rollback_while_async(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + cur.execute("select 'a'") - cur.execute("select 'a'") - - # a rollback should not work in asynchronous mode - self.assertRaises(psycopg2.ProgrammingError, self.conn.rollback) + # a rollback should not work in asynchronous mode + self.assertRaises(psycopg2.ProgrammingError, self.conn.rollback) def test_commit_while_async(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + cur.execute("begin") + self.wait(cur) - cur.execute("begin") - self.wait(cur) + cur.execute("insert into table1 values (1)") - cur.execute("insert into table1 values (1)") + # a commit should not work in asynchronous mode + self.assertRaises(psycopg2.ProgrammingError, self.conn.commit) + self.assertTrue(self.conn.isexecuting()) - # a commit should not work in asynchronous mode - self.assertRaises(psycopg2.ProgrammingError, self.conn.commit) - self.assertTrue(self.conn.isexecuting()) + # but a manual commit should + self.wait(cur) + cur.execute("commit") + self.wait(cur) - # but a manual commit should - self.wait(cur) - cur.execute("commit") - self.wait(cur) + cur.execute("select * from table1") + self.wait(cur) + self.assertEquals(cur.fetchall()[0][0], 1) - cur.execute("select * from table1") - self.wait(cur) - self.assertEquals(cur.fetchall()[0][0], 1) + cur.execute("delete from table1") + self.wait(cur) - cur.execute("delete from table1") - self.wait(cur) - - cur.execute("select * from table1") - self.wait(cur) - self.assertEquals(cur.fetchone(), None) + cur.execute("select * from table1") + self.wait(cur) + self.assertEquals(cur.fetchone(), None) def test_set_parameters_while_async(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + cur.execute("select 'c'") + self.assertTrue(self.conn.isexecuting()) - cur.execute("select 'c'") - self.assertTrue(self.conn.isexecuting()) + # getting transaction status works + self.assertEquals(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_ACTIVE) + self.assertTrue(self.conn.isexecuting()) - # getting transaction status works - self.assertEquals(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_ACTIVE) - self.assertTrue(self.conn.isexecuting()) + # setting connection encoding should fail + self.assertRaises(psycopg2.ProgrammingError, + self.conn.set_client_encoding, "LATIN1") - # setting connection encoding should fail - self.assertRaises(psycopg2.ProgrammingError, - self.conn.set_client_encoding, "LATIN1") - - # same for transaction isolation - self.assertRaises(psycopg2.ProgrammingError, - self.conn.set_isolation_level, 1) + # same for transaction isolation + self.assertRaises(psycopg2.ProgrammingError, + self.conn.set_isolation_level, 1) def test_reset_while_async(self): - cur = self.conn.cursor() - cur.execute("select 'c'") - self.assertTrue(self.conn.isexecuting()) + with self.conn.cursor() as cur: + cur.execute("select 'c'") + self.assertTrue(self.conn.isexecuting()) - # a reset should fail - self.assertRaises(psycopg2.ProgrammingError, self.conn.reset) + # a reset should fail + self.assertRaises(psycopg2.ProgrammingError, self.conn.reset) def test_async_iter(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + cur.execute("begin") + self.wait(cur) + cur.execute(""" + insert into table1 values (1); + insert into table1 values (2); + insert into table1 values (3); + """) + self.wait(cur) + cur.execute("select id from table1 order by id") - cur.execute("begin") - self.wait(cur) - cur.execute(""" - insert into table1 values (1); - insert into table1 values (2); - insert into table1 values (3); - """) - self.wait(cur) - cur.execute("select id from table1 order by id") + # iteration fails if a query is underway + self.assertRaises(psycopg2.ProgrammingError, list, cur) - # iteration fails if a query is underway - self.assertRaises(psycopg2.ProgrammingError, list, cur) - - # but after it's done it should work - self.wait(cur) - self.assertEquals(list(cur), [(1, ), (2, ), (3, )]) - self.assertFalse(self.conn.isexecuting()) + # but after it's done it should work + self.wait(cur) + self.assertEquals(list(cur), [(1, ), (2, ), (3, )]) + self.assertFalse(self.conn.isexecuting()) def test_copy_while_async(self): - cur = self.conn.cursor() - cur.execute("select 'a'") + with self.conn.cursor() as cur: + cur.execute("select 'a'") - # copy should fail - self.assertRaises(psycopg2.ProgrammingError, - cur.copy_from, - StringIO("1\n3\n5\n\\.\n"), "table1") + # copy should fail + self.assertRaises(psycopg2.ProgrammingError, + cur.copy_from, + StringIO("1\n3\n5\n\\.\n"), "table1") def test_lobject_while_async(self): # large objects should be prohibited @@ -250,68 +249,68 @@ class AsyncTests(ConnectingTestCase): self.conn.lobject) def test_async_executemany(self): - cur = self.conn.cursor() - self.assertRaises( - psycopg2.ProgrammingError, - cur.executemany, "insert into table1 values (%s)", [1, 2, 3]) + with self.conn.cursor() as cur: + self.assertRaises( + psycopg2.ProgrammingError, + cur.executemany, "insert into table1 values (%s)", [1, 2, 3]) def test_async_scroll(self): - cur = self.conn.cursor() - cur.execute(""" - insert into table1 values (1); - insert into table1 values (2); - insert into table1 values (3); - """) - self.wait(cur) - cur.execute("select id from table1 order by id") + with self.conn.cursor() as cur: + cur.execute(""" + insert into table1 values (1); + insert into table1 values (2); + insert into table1 values (3); + """) + self.wait(cur) + cur.execute("select id from table1 order by id") - # scroll should fail if a query is underway - self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 1) - self.assertTrue(self.conn.isexecuting()) + # scroll should fail if a query is underway + self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 1) + self.assertTrue(self.conn.isexecuting()) - # but after it's done it should work - self.wait(cur) - cur.scroll(1) - self.assertEquals(cur.fetchall(), [(2, ), (3, )]) + # but after it's done it should work + self.wait(cur) + cur.scroll(1) + self.assertEquals(cur.fetchall(), [(2, ), (3, )]) - cur = self.conn.cursor() - cur.execute("select id from table1 order by id") - self.wait(cur) + with self.conn.cursor() as cur2: + cur.execute("select id from table1 order by id") + self.wait(cur) - cur2 = self.conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1) + with self.conn.cursor() as cur2: + self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1) - self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4) + self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4) - cur = self.conn.cursor() - cur.execute("select id from table1 order by id") - self.wait(cur) - cur.scroll(2) - cur.scroll(-1) - self.assertEquals(cur.fetchall(), [(2, ), (3, )]) + with self.conn.cursor() as cur: + cur.execute("select id from table1 order by id") + self.wait(cur) + cur.scroll(2) + cur.scroll(-1) + self.assertEquals(cur.fetchall(), [(2, ), (3, )]) def test_scroll(self): - cur = self.sync_conn.cursor() - cur.execute("create table table1 (id int)") - cur.execute(""" - insert into table1 values (1); - insert into table1 values (2); - insert into table1 values (3); - """) - cur.execute("select id from table1 order by id") - cur.scroll(2) - cur.scroll(-1) - self.assertEquals(cur.fetchall(), [(2, ), (3, )]) + with self.sync_conn.cursor() as cur: + cur.execute("create table table1 (id int)") + cur.execute(""" + insert into table1 values (1); + insert into table1 values (2); + insert into table1 values (3); + """) + cur.execute("select id from table1 order by id") + cur.scroll(2) + cur.scroll(-1) + self.assertEquals(cur.fetchall(), [(2, ), (3, )]) def test_async_dont_read_all(self): - cur = self.conn.cursor() - cur.execute("select repeat('a', 10000); select repeat('b', 10000)") + with self.conn.cursor() as cur: + cur.execute("select repeat('a', 10000); select repeat('b', 10000)") - # fetch the result - self.wait(cur) + # fetch the result + self.wait(cur) - # it should be the result of the second query - self.assertEquals(cur.fetchone()[0], "b" * 10000) + # it should be the result of the second query + self.assertEquals(cur.fetchone()[0], "b" * 10000) def test_async_subclass(self): class MyConn(ext.connection): @@ -326,15 +325,15 @@ class AsyncTests(ConnectingTestCase): @slow def test_flush_on_write(self): # a very large query requires a flush loop to be sent to the backend - curs = self.conn.cursor() - for mb in 1, 5, 10, 20, 50: - size = mb * 1024 * 1024 - stub = PollableStub(self.conn) - curs.execute("select %s;", ('x' * size,)) - self.wait(stub) - self.assertEqual(size, len(curs.fetchone()[0])) - if stub.polls.count(ext.POLL_WRITE) > 1: - return + with self.conn.cursor() as curs: + for mb in 1, 5, 10, 20, 50: + size = mb * 1024 * 1024 + stub = PollableStub(self.conn) + curs.execute("select %s;", ('x' * size,)) + self.wait(stub) + self.assertEqual(size, len(curs.fetchone()[0])) + if stub.polls.count(ext.POLL_WRITE) > 1: + return # This is more a testing glitch than an error: it happens # on high load on linux: probably because the kernel has more @@ -343,112 +342,108 @@ class AsyncTests(ConnectingTestCase): warnings.warn("sending a large query didn't trigger block on write.") def test_sync_poll(self): - cur = self.sync_conn.cursor() - cur.execute("select 1") - # polling with a sync query works - cur.connection.poll() - self.assertEquals(cur.fetchone()[0], 1) + with self.sync_conn.cursor() as cur: + cur.execute("select 1") + # polling with a sync query works + cur.connection.poll() + self.assertEquals(cur.fetchone()[0], 1) @slow def test_notify(self): - cur = self.conn.cursor() - sync_cur = self.sync_conn.cursor() + with self.conn.cursor() as cur, self.sync_conn.cursor() as sync_cur: + sync_cur.execute("listen test_notify") + self.sync_conn.commit() + cur.execute("notify test_notify") + self.wait(cur) - sync_cur.execute("listen test_notify") - self.sync_conn.commit() - cur.execute("notify test_notify") - self.wait(cur) + self.assertEquals(self.sync_conn.notifies, []) - self.assertEquals(self.sync_conn.notifies, []) - - pid = self.conn.info.backend_pid - for _ in range(5): - self.wait(self.sync_conn) - if not self.sync_conn.notifies: - time.sleep(0.5) - continue - self.assertEquals(len(self.sync_conn.notifies), 1) - self.assertEquals(self.sync_conn.notifies.pop(), - (pid, "test_notify")) - return + pid = self.conn.info.backend_pid + for _ in range(5): + self.wait(self.sync_conn) + if not self.sync_conn.notifies: + time.sleep(0.5) + continue + self.assertEquals(len(self.sync_conn.notifies), 1) + self.assertEquals(self.sync_conn.notifies.pop(), + (pid, "test_notify")) + return self.fail("No NOTIFY in 2.5 seconds") def test_async_fetch_wrong_cursor(self): - cur1 = self.conn.cursor() - cur2 = self.conn.cursor() - cur1.execute("select 1") + with self.conn.cursor() as cur1, self.conn.cursor() as cur2: + cur1.execute("select 1") - self.wait(cur1) - self.assertFalse(self.conn.isexecuting()) - # fetching from a cursor with no results is an error - self.assertRaises(psycopg2.ProgrammingError, cur2.fetchone) - # fetching from the correct cursor works - self.assertEquals(cur1.fetchone()[0], 1) + self.wait(cur1) + self.assertFalse(self.conn.isexecuting()) + # fetching from a cursor with no results is an error + self.assertRaises(psycopg2.ProgrammingError, cur2.fetchone) + # fetching from the correct cursor works + self.assertEquals(cur1.fetchone()[0], 1) def test_error(self): - cur = self.conn.cursor() - cur.execute("insert into table1 values (%s)", (1, )) - self.wait(cur) - cur.execute("insert into table1 values (%s)", (1, )) - # this should fail - self.assertRaises(psycopg2.IntegrityError, self.wait, cur) - cur.execute("insert into table1 values (%s); " - "insert into table1 values (%s)", (2, 2)) - # this should fail as well - self.assertRaises(psycopg2.IntegrityError, self.wait, cur) - # but this should work - cur.execute("insert into table1 values (%s)", (2, )) - self.wait(cur) - # and the cursor should be usable afterwards - cur.execute("insert into table1 values (%s)", (3, )) - self.wait(cur) - cur.execute("select * from table1 order by id") - self.wait(cur) - self.assertEquals(cur.fetchall(), [(1, ), (2, ), (3, )]) - cur.execute("delete from table1") - self.wait(cur) + with self.conn.cursor() as cur: + cur.execute("insert into table1 values (%s)", (1, )) + self.wait(cur) + cur.execute("insert into table1 values (%s)", (1, )) + # this should fail + self.assertRaises(psycopg2.IntegrityError, self.wait, cur) + cur.execute("insert into table1 values (%s); " + "insert into table1 values (%s)", (2, 2)) + # this should fail as well + self.assertRaises(psycopg2.IntegrityError, self.wait, cur) + # but this should work + cur.execute("insert into table1 values (%s)", (2, )) + self.wait(cur) + # and the cursor should be usable afterwards + cur.execute("insert into table1 values (%s)", (3, )) + self.wait(cur) + cur.execute("select * from table1 order by id") + self.wait(cur) + self.assertEquals(cur.fetchall(), [(1, ), (2, ), (3, )]) + cur.execute("delete from table1") + self.wait(cur) def test_stop_on_first_error(self): - cur = self.conn.cursor() - cur.execute("select 1; select x; select 1/0; select 2") - self.assertRaises(psycopg2.errors.UndefinedColumn, self.wait, cur) + with self.conn.cursor() as cur: + cur.execute("select 1; select x; select 1/0; select 2") + self.assertRaises(psycopg2.errors.UndefinedColumn, self.wait, cur) - cur.execute("select 1") - self.wait(cur) - self.assertEqual(cur.fetchone(), (1,)) + cur.execute("select 1") + self.wait(cur) + self.assertEqual(cur.fetchone(), (1,)) def test_error_two_cursors(self): - cur = self.conn.cursor() - cur2 = self.conn.cursor() - cur.execute("select * from no_such_table") - self.assertRaises(psycopg2.ProgrammingError, self.wait, cur) - cur2.execute("select 1") - self.wait(cur2) - self.assertEquals(cur2.fetchone()[0], 1) + with self.conn.cursor() as cur, self.conn.cursor() as cur2: + cur.execute("select * from no_such_table") + self.assertRaises(psycopg2.ProgrammingError, self.wait, cur) + cur2.execute("select 1") + self.wait(cur2) + self.assertEquals(cur2.fetchone()[0], 1) def test_notices(self): del self.conn.notices[:] - cur = self.conn.cursor() - if self.conn.info.server_version >= 90300: - cur.execute("set client_min_messages=debug1") + with self.conn.cursor() as cur: + if self.conn.info.server_version >= 90300: + cur.execute("set client_min_messages=debug1") + self.wait(cur) + cur.execute("create temp table chatty (id serial primary key);") self.wait(cur) - cur.execute("create temp table chatty (id serial primary key);") - self.wait(cur) - self.assertEqual("CREATE TABLE", cur.statusmessage) - self.assert_(self.conn.notices) + self.assertEqual("CREATE TABLE", cur.statusmessage) + self.assert_(self.conn.notices) def test_async_cursor_gone(self): - cur = self.conn.cursor() - cur.execute("select 42;") + with self.conn.cursor() as cur: + cur.execute("select 42;") del cur gc.collect() self.assertRaises(psycopg2.InterfaceError, self.wait, self.conn) # The connection is still usable - cur = self.conn.cursor() - cur.execute("select 42;") - self.wait(self.conn) - self.assertEqual(cur.fetchone(), (42,)) + with self.conn.cursor() as cur: + cur.execute("select 42;") + self.wait(self.conn) + self.assertEqual(cur.fetchone(), (42,)) def test_async_connection_error_message(self): cnn = psycopg2.connect('dbname=thisdatabasedoesntexist', async_=True) @@ -464,40 +459,40 @@ class AsyncTests(ConnectingTestCase): @skip_before_postgres(8, 2) def test_copy_no_hang(self): - cur = self.conn.cursor() - cur.execute("copy (select 1) to stdout") - self.assertRaises(psycopg2.ProgrammingError, self.wait, self.conn) + with self.conn.cursor() as cur: + cur.execute("copy (select 1) to stdout") + self.assertRaises(psycopg2.ProgrammingError, self.wait, self.conn) @slow @skip_before_postgres(9, 0) def test_non_block_after_notification(self): from select import select - cur = self.conn.cursor() - cur.execute(""" - select 1; - do $$ - begin - raise notice 'hello'; - end - $$ language plpgsql; - select pg_sleep(1); - """) + with self.conn.cursor() as cur: + cur.execute(""" + select 1; + do $$ + begin + raise notice 'hello'; + end + $$ language plpgsql; + select pg_sleep(1); + """) - polls = 0 - while True: - state = self.conn.poll() - if state == psycopg2.extensions.POLL_OK: - break - elif state == psycopg2.extensions.POLL_READ: - select([self.conn], [], [], 0.1) - elif state == psycopg2.extensions.POLL_WRITE: - select([], [self.conn], [], 0.1) - else: - raise Exception("Unexpected result from poll: %r", state) - polls += 1 + polls = 0 + while True: + state = self.conn.poll() + if state == psycopg2.extensions.POLL_OK: + break + elif state == psycopg2.extensions.POLL_READ: + select([self.conn], [], [], 0.1) + elif state == psycopg2.extensions.POLL_WRITE: + select([], [self.conn], [], 0.1) + else: + raise Exception("Unexpected result from poll: %r", state) + polls += 1 - self.assert_(polls >= 8, polls) + self.assert_(polls >= 8, polls) def test_poll_noop(self): self.conn.poll() diff --git a/tests/test_async_keyword.py b/tests/test_async_keyword.py index f8e50afe..dbedec01 100755 --- a/tests/test_async_keyword.py +++ b/tests/test_async_keyword.py @@ -47,16 +47,18 @@ class AsyncTests(ConnectingTestCase): self.wait(self.conn) - curs = self.conn.cursor() - curs.execute(''' - CREATE TEMPORARY TABLE table1 ( - id int PRIMARY KEY - )''') - self.wait(curs) + with self.conn.cursor() as curs: + curs.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') + self.wait(curs) def test_connection_setup(self): cur = self.conn.cursor() sync_cur = self.sync_conn.cursor() + cur.close() + sync_cur.close() del cur, sync_cur self.assert_(self.conn.async) @@ -97,11 +99,11 @@ class CancelTests(ConnectingTestCase): def setUp(self): ConnectingTestCase.setUp(self) - cur = self.conn.cursor() - cur.execute(''' - CREATE TEMPORARY TABLE table1 ( - id int PRIMARY KEY - )''') + with self.conn.cursor() as cur: + cur.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') self.conn.commit() @slow @@ -110,16 +112,16 @@ class CancelTests(ConnectingTestCase): async_conn = psycopg2.connect(dsn, async=True) self.assertRaises(psycopg2.OperationalError, async_conn.cancel) extras.wait_select(async_conn) - cur = async_conn.cursor() - cur.execute("select pg_sleep(10)") - time.sleep(1) - self.assertTrue(async_conn.isexecuting()) - async_conn.cancel() - self.assertRaises(psycopg2.extensions.QueryCanceledError, - extras.wait_select, async_conn) - cur.execute("select 1") - extras.wait_select(async_conn) - self.assertEqual(cur.fetchall(), [(1, )]) + with async_conn.cursor() as cur: + cur.execute("select pg_sleep(10)") + time.sleep(1) + self.assertTrue(async_conn.isexecuting()) + async_conn.cancel() + self.assertRaises(psycopg2.extensions.QueryCanceledError, + extras.wait_select, async_conn) + cur.execute("select 1") + extras.wait_select(async_conn) + self.assertEqual(cur.fetchall(), [(1, )]) async_conn.close() def test_async_connection_cancel(self): @@ -183,41 +185,40 @@ class AsyncReplicationTest(ReplicationTestCase): if conn is None: return - cur = conn.cursor() + with conn.cursor() as cur: + self.create_replication_slot(cur, output_plugin='test_decoding') + self.wait(cur) - self.create_replication_slot(cur, output_plugin='test_decoding') - self.wait(cur) + cur.start_replication(self.slot) + self.wait(cur) - cur.start_replication(self.slot) - self.wait(cur) + self.make_replication_events() - self.make_replication_events() + self.msg_count = 0 - self.msg_count = 0 + def consume(msg): + # just check the methods + "%s: %s" % (cur.io_timestamp, repr(msg)) + "%s: %s" % (cur.feedback_timestamp, repr(msg)) - def consume(msg): - # just check the methods - "%s: %s" % (cur.io_timestamp, repr(msg)) - "%s: %s" % (cur.feedback_timestamp, repr(msg)) + self.msg_count += 1 + if self.msg_count > 3: + cur.send_feedback(reply=True) + raise StopReplication() - self.msg_count += 1 - if self.msg_count > 3: - cur.send_feedback(reply=True) - raise StopReplication() + cur.send_feedback(flush_lsn=msg.data_start) - cur.send_feedback(flush_lsn=msg.data_start) + # cannot be used in asynchronous mode + self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume) - # cannot be used in asynchronous mode - self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume) - - def process_stream(): - while True: - msg = cur.read_message() - if msg: - consume(msg) - else: - select([cur], [], []) - self.assertRaises(StopReplication, process_stream) + def process_stream(): + while True: + msg = cur.read_message() + if msg: + consume(msg) + else: + select([cur], [], []) + self.assertRaises(StopReplication, process_stream) def test_suite(): diff --git a/tests/test_bug_gc.py b/tests/test_bug_gc.py index 3d0b789e..ee82336d 100755 --- a/tests/test_bug_gc.py +++ b/tests/test_bug_gc.py @@ -39,9 +39,9 @@ class StolenReferenceTestCase(ConnectingTestCase): return 42 UUID = psycopg2.extensions.new_type((2950,), "UUID", fish) psycopg2.extensions.register_type(UUID, self.conn) - curs = self.conn.cursor() - curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid") - curs.fetchone() + with self.conn.cursor() as curs: + curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid") + curs.fetchone() def test_suite(): diff --git a/tests/test_cancel.py b/tests/test_cancel.py index 06477edc..b4daba3e 100755 --- a/tests/test_cancel.py +++ b/tests/test_cancel.py @@ -41,11 +41,11 @@ class CancelTests(ConnectingTestCase): def setUp(self): ConnectingTestCase.setUp(self) - cur = self.conn.cursor() - cur.execute(''' - CREATE TEMPORARY TABLE table1 ( - id int PRIMARY KEY - )''') + with self.conn.cursor() as cur: + cur.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') self.conn.commit() def test_empty_cancel(self): @@ -57,25 +57,25 @@ class CancelTests(ConnectingTestCase): errors = [] def neverending(conn): - cur = conn.cursor() - try: - self.assertRaises(psycopg2.extensions.QueryCanceledError, - cur.execute, "select pg_sleep(60)") - # make sure the connection still works - conn.rollback() - cur.execute("select 1") - self.assertEqual(cur.fetchall(), [(1, )]) - except Exception as e: - errors.append(e) - raise + with conn.cursor() as cur: + try: + self.assertRaises(psycopg2.extensions.QueryCanceledError, + cur.execute, "select pg_sleep(60)") + # make sure the connection still works + conn.rollback() + cur.execute("select 1") + self.assertEqual(cur.fetchall(), [(1, )]) + except Exception as e: + errors.append(e) + raise def canceller(conn): - cur = conn.cursor() - try: - conn.cancel() - except Exception as e: - errors.append(e) - raise + with conn.cursor() as cur: + try: + conn.cancel() + except Exception as e: + errors.append(e) + raise del cur thread1 = threading.Thread(target=neverending, args=(self.conn, )) @@ -95,16 +95,16 @@ class CancelTests(ConnectingTestCase): async_conn = psycopg2.connect(dsn, async_=True) self.assertRaises(psycopg2.OperationalError, async_conn.cancel) extras.wait_select(async_conn) - cur = async_conn.cursor() - cur.execute("select pg_sleep(10)") - time.sleep(1) - self.assertTrue(async_conn.isexecuting()) - async_conn.cancel() - self.assertRaises(psycopg2.extensions.QueryCanceledError, - extras.wait_select, async_conn) - cur.execute("select 1") - extras.wait_select(async_conn) - self.assertEqual(cur.fetchall(), [(1, )]) + with async_conn.cursor() as cur: + cur.execute("select pg_sleep(10)") + time.sleep(1) + self.assertTrue(async_conn.isexecuting()) + async_conn.cancel() + self.assertRaises(psycopg2.extensions.QueryCanceledError, + extras.wait_select, async_conn) + cur.execute("select 1") + extras.wait_select(async_conn) + self.assertEqual(cur.fetchall(), [(1, )]) async_conn.close() def test_async_connection_cancel(self): diff --git a/tests/test_connection.py b/tests/test_connection.py index f1668d55..c724a7fa 100755 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -80,9 +80,9 @@ class ConnectionTests(ConnectingTestCase): def test_cleanup_on_badconn_close(self): # ticket #148 conn = self.conn - cur = conn.cursor() - self.assertRaises(psycopg2.OperationalError, - cur.execute, "select pg_terminate_backend(pg_backend_pid())") + with conn.cursor() as cur: + self.assertRaises(psycopg2.OperationalError, + cur.execute, "select pg_terminate_backend(pg_backend_pid())") self.assertEqual(conn.closed, 2) conn.close() @@ -113,87 +113,87 @@ class ConnectionTests(ConnectingTestCase): def test_notices(self): conn = self.conn - cur = conn.cursor() - if self.conn.info.server_version >= 90300: - cur.execute("set client_min_messages=debug1") - cur.execute("create temp table chatty (id serial primary key);") - self.assertEqual("CREATE TABLE", cur.statusmessage) - self.assert_(conn.notices) + with conn.cursor() as cur: + if self.conn.info.server_version >= 90300: + cur.execute("set client_min_messages=debug1") + cur.execute("create temp table chatty (id serial primary key);") + self.assertEqual("CREATE TABLE", cur.statusmessage) + self.assert_(conn.notices) def test_notices_consistent_order(self): conn = self.conn - cur = conn.cursor() - if self.conn.info.server_version >= 90300: - cur.execute("set client_min_messages=debug1") - cur.execute(""" - create temp table table1 (id serial); - create temp table table2 (id serial); - """) - cur.execute(""" - create temp table table3 (id serial); - create temp table table4 (id serial); - """) - self.assertEqual(4, len(conn.notices)) - self.assert_('table1' in conn.notices[0]) - self.assert_('table2' in conn.notices[1]) - self.assert_('table3' in conn.notices[2]) - self.assert_('table4' in conn.notices[3]) + with conn.cursor() as cur: + if self.conn.info.server_version >= 90300: + cur.execute("set client_min_messages=debug1") + cur.execute(""" + create temp table table1 (id serial); + create temp table table2 (id serial); + """) + cur.execute(""" + create temp table table3 (id serial); + create temp table table4 (id serial); + """) + self.assertEqual(4, len(conn.notices)) + self.assert_('table1' in conn.notices[0]) + self.assert_('table2' in conn.notices[1]) + self.assert_('table3' in conn.notices[2]) + self.assert_('table4' in conn.notices[3]) @slow def test_notices_limited(self): conn = self.conn - cur = conn.cursor() - if self.conn.info.server_version >= 90300: - cur.execute("set client_min_messages=debug1") - for i in range(0, 100, 10): - sql = " ".join(["create temp table table%d (id serial);" % j - for j in range(i, i + 10)]) - cur.execute(sql) + with conn.cursor() as cur: + if self.conn.info.server_version >= 90300: + cur.execute("set client_min_messages=debug1") + for i in range(0, 100, 10): + sql = " ".join(["create temp table table%d (id serial);" % j + for j in range(i, i + 10)]) + cur.execute(sql) - self.assertEqual(50, len(conn.notices)) - self.assert_('table99' in conn.notices[-1], conn.notices[-1]) + self.assertEqual(50, len(conn.notices)) + self.assert_('table99' in conn.notices[-1], conn.notices[-1]) @slow def test_notices_deque(self): conn = self.conn self.conn.notices = deque() - cur = conn.cursor() - if self.conn.info.server_version >= 90300: - cur.execute("set client_min_messages=debug1") + with conn.cursor() as cur: + if self.conn.info.server_version >= 90300: + cur.execute("set client_min_messages=debug1") - cur.execute(""" - create temp table table1 (id serial); - create temp table table2 (id serial); - """) - cur.execute(""" - create temp table table3 (id serial); - create temp table table4 (id serial);""") - self.assertEqual(len(conn.notices), 4) - self.assert_('table1' in conn.notices.popleft()) - self.assert_('table2' in conn.notices.popleft()) - self.assert_('table3' in conn.notices.popleft()) - self.assert_('table4' in conn.notices.popleft()) - self.assertEqual(len(conn.notices), 0) + cur.execute(""" + create temp table table1 (id serial); + create temp table table2 (id serial); + """) + cur.execute(""" + create temp table table3 (id serial); + create temp table table4 (id serial);""") + self.assertEqual(len(conn.notices), 4) + self.assert_('table1' in conn.notices.popleft()) + self.assert_('table2' in conn.notices.popleft()) + self.assert_('table3' in conn.notices.popleft()) + self.assert_('table4' in conn.notices.popleft()) + self.assertEqual(len(conn.notices), 0) - # not limited, but no error - for i in range(0, 100, 10): - sql = " ".join(["create temp table table2_%d (id serial);" % j - for j in range(i, i + 10)]) - cur.execute(sql) + # not limited, but no error + for i in range(0, 100, 10): + sql = " ".join(["create temp table table2_%d (id serial);" % j + for j in range(i, i + 10)]) + cur.execute(sql) - self.assertEqual(len([n for n in conn.notices if 'CREATE TABLE' in n]), - 100) + self.assertEqual(len([n for n in conn.notices if 'CREATE TABLE' in n]), + 100) def test_notices_noappend(self): conn = self.conn self.conn.notices = None # will make an error swallowes ok - cur = conn.cursor() - if self.conn.info.server_version >= 90300: - cur.execute("set client_min_messages=debug1") + with conn.cursor() as cur: + if self.conn.info.server_version >= 90300: + cur.execute("set client_min_messages=debug1") - cur.execute("create temp table table1 (id serial);") + cur.execute("create temp table table1 (id serial);") - self.assertEqual(self.conn.notices, None) + self.assertEqual(self.conn.notices, None) def test_server_version(self): self.assert_(self.conn.server_version) @@ -233,10 +233,10 @@ class ConnectionTests(ConnectingTestCase): def test_encoding_name(self): self.conn.set_client_encoding("EUC_JP") # conn.encoding is 'EUCJP' now. - cur = self.conn.cursor() - ext.register_type(ext.UNICODE, cur) - cur.execute("select 'foo'::text;") - self.assertEqual(cur.fetchone()[0], u'foo') + with self.conn.cursor() as cur: + ext.register_type(ext.UNICODE, cur) + cur.execute("select 'foo'::text;") + self.assertEqual(cur.fetchone()[0], u'foo') def test_connect_nonnormal_envvar(self): # We must perform encoding normalization at connection time @@ -281,14 +281,14 @@ class ConnectionTests(ConnectingTestCase): while conn.notices: notices.append((2, conn.notices.pop())) - cur = conn.cursor() - t1 = threading.Thread(target=committer) - t1.start() - for i in range(1000): - cur.execute("select %s;", (i,)) - conn.commit() - while conn.notices: - notices.append((1, conn.notices.pop())) + with conn.cursor() as cur: + t1 = threading.Thread(target=committer) + t1.start() + for i in range(1000): + cur.execute("select %s;", (i,)) + conn.commit() + while conn.notices: + notices.append((1, conn.notices.pop())) # Stop the committer thread stop.append(True) @@ -297,37 +297,37 @@ class ConnectionTests(ConnectingTestCase): def test_connect_cursor_factory(self): conn = self.connect(cursor_factory=psycopg2.extras.DictCursor) - cur = conn.cursor() - cur.execute("select 1 as a") - self.assertEqual(cur.fetchone()['a'], 1) + with conn.cursor() as cur: + cur.execute("select 1 as a") + self.assertEqual(cur.fetchone()['a'], 1) def test_cursor_factory(self): self.assertEqual(self.conn.cursor_factory, None) - cur = self.conn.cursor() - cur.execute("select 1 as a") - self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone()) + with self.conn.cursor() as cur: + cur.execute("select 1 as a") + self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone()) self.conn.cursor_factory = psycopg2.extras.DictCursor self.assertEqual(self.conn.cursor_factory, psycopg2.extras.DictCursor) - cur = self.conn.cursor() - cur.execute("select 1 as a") - self.assertEqual(cur.fetchone()['a'], 1) + with self.conn.cursor() as cur: + cur.execute("select 1 as a") + self.assertEqual(cur.fetchone()['a'], 1) self.conn.cursor_factory = None self.assertEqual(self.conn.cursor_factory, None) - cur = self.conn.cursor() - cur.execute("select 1 as a") - self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone()) + with self.conn.cursor() as cur: + cur.execute("select 1 as a") + self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone()) def test_cursor_factory_none(self): # issue #210 conn = self.connect() - cur = conn.cursor(cursor_factory=None) - self.assertEqual(type(cur), ext.cursor) + with conn.cursor(cursor_factory=None) as cur: + self.assertEqual(type(cur), ext.cursor) conn = self.connect(cursor_factory=psycopg2.extras.DictCursor) - cur = conn.cursor(cursor_factory=None) - self.assertEqual(type(cur), psycopg2.extras.DictCursor) + with conn.cursor(cursor_factory=None) as cur: + self.assertEqual(type(cur), psycopg2.extras.DictCursor) def test_failed_init_status(self): class SubConnection(ext.connection): @@ -583,179 +583,175 @@ class IsolationLevelsTestCase(ConnectingTestCase): def test_set_isolation_level(self): conn = self.connect() - curs = conn.cursor() + with conn.cursor() as curs: + levels = [ + ('read uncommitted', + ext.ISOLATION_LEVEL_READ_UNCOMMITTED), + ('read committed', ext.ISOLATION_LEVEL_READ_COMMITTED), + ('repeatable read', ext.ISOLATION_LEVEL_REPEATABLE_READ), + ('serializable', ext.ISOLATION_LEVEL_SERIALIZABLE), + ] + for name, level in levels: + conn.set_isolation_level(level) - levels = [ - ('read uncommitted', - ext.ISOLATION_LEVEL_READ_UNCOMMITTED), - ('read committed', ext.ISOLATION_LEVEL_READ_COMMITTED), - ('repeatable read', ext.ISOLATION_LEVEL_REPEATABLE_READ), - ('serializable', ext.ISOLATION_LEVEL_SERIALIZABLE), - ] - for name, level in levels: - conn.set_isolation_level(level) + # the only values available on prehistoric PG versions + if conn.info.server_version < 80000: + if level in ( + ext.ISOLATION_LEVEL_READ_UNCOMMITTED, + ext.ISOLATION_LEVEL_REPEATABLE_READ): + name, level = levels[levels.index((name, level)) + 1] - # the only values available on prehistoric PG versions - if conn.info.server_version < 80000: - if level in ( - ext.ISOLATION_LEVEL_READ_UNCOMMITTED, - ext.ISOLATION_LEVEL_REPEATABLE_READ): - name, level = levels[levels.index((name, level)) + 1] + self.assertEqual(conn.isolation_level, level) - self.assertEqual(conn.isolation_level, level) + curs.execute('show transaction_isolation;') + got_name = curs.fetchone()[0] - curs.execute('show transaction_isolation;') - got_name = curs.fetchone()[0] + self.assertEqual(name, got_name) + conn.commit() - self.assertEqual(name, got_name) - conn.commit() - - self.assertRaises(ValueError, conn.set_isolation_level, -1) - self.assertRaises(ValueError, conn.set_isolation_level, 5) + self.assertRaises(ValueError, conn.set_isolation_level, -1) + self.assertRaises(ValueError, conn.set_isolation_level, 5) def test_set_isolation_level_autocommit(self): conn = self.connect() - curs = conn.cursor() + with conn.cursor() as curs: + conn.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT) + self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT) + self.assert_(conn.autocommit) - conn.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT) - self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT) - self.assert_(conn.autocommit) + conn.isolation_level = 'serializable' + self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) + self.assert_(conn.autocommit) - conn.isolation_level = 'serializable' - self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) - self.assert_(conn.autocommit) - - curs.execute('show transaction_isolation;') - self.assertEqual(curs.fetchone()[0], 'serializable') + curs.execute('show transaction_isolation;') + self.assertEqual(curs.fetchone()[0], 'serializable') def test_set_isolation_level_default(self): conn = self.connect() - curs = conn.cursor() + with conn.cursor() as curs: + conn.autocommit = True + curs.execute("set default_transaction_isolation to 'read committed'") - conn.autocommit = True - curs.execute("set default_transaction_isolation to 'read committed'") + conn.autocommit = False + conn.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE) + self.assertEqual(conn.isolation_level, + ext.ISOLATION_LEVEL_SERIALIZABLE) + curs.execute("show transaction_isolation") + self.assertEqual(curs.fetchone()[0], "serializable") - conn.autocommit = False - conn.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE) - self.assertEqual(conn.isolation_level, - ext.ISOLATION_LEVEL_SERIALIZABLE) - curs.execute("show transaction_isolation") - self.assertEqual(curs.fetchone()[0], "serializable") - - conn.rollback() - conn.set_isolation_level(ext.ISOLATION_LEVEL_DEFAULT) - curs.execute("show transaction_isolation") - self.assertEqual(curs.fetchone()[0], "read committed") + conn.rollback() + conn.set_isolation_level(ext.ISOLATION_LEVEL_DEFAULT) + curs.execute("show transaction_isolation") + self.assertEqual(curs.fetchone()[0], "read committed") def test_set_isolation_level_abort(self): conn = self.connect() - cur = conn.cursor() + with conn.cursor() as cur: + self.assertEqual(ext.TRANSACTION_STATUS_IDLE, + conn.info.transaction_status) + cur.execute("insert into isolevel values (10);") + self.assertEqual(ext.TRANSACTION_STATUS_INTRANS, + conn.info.transaction_status) - self.assertEqual(ext.TRANSACTION_STATUS_IDLE, - conn.info.transaction_status) - cur.execute("insert into isolevel values (10);") - self.assertEqual(ext.TRANSACTION_STATUS_INTRANS, - conn.info.transaction_status) + conn.set_isolation_level( + psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE) + self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, + conn.info.transaction_status) + cur.execute("select count(*) from isolevel;") + self.assertEqual(0, cur.fetchone()[0]) - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE) - self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, - conn.info.transaction_status) - cur.execute("select count(*) from isolevel;") - self.assertEqual(0, cur.fetchone()[0]) + cur.execute("insert into isolevel values (10);") + self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_INTRANS, + conn.info.transaction_status) + conn.set_isolation_level( + psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) + self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, + conn.info.transaction_status) + cur.execute("select count(*) from isolevel;") + self.assertEqual(0, cur.fetchone()[0]) - cur.execute("insert into isolevel values (10);") - self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_INTRANS, - conn.info.transaction_status) - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) - self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, - conn.info.transaction_status) - cur.execute("select count(*) from isolevel;") - self.assertEqual(0, cur.fetchone()[0]) - - cur.execute("insert into isolevel values (10);") - self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, - conn.info.transaction_status) - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED) - self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, - conn.info.transaction_status) - cur.execute("select count(*) from isolevel;") - self.assertEqual(1, cur.fetchone()[0]) - self.assertEqual(conn.isolation_level, - psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED) + cur.execute("insert into isolevel values (10);") + self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, + conn.info.transaction_status) + conn.set_isolation_level( + psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED) + self.assertEqual(psycopg2.extensions.TRANSACTION_STATUS_IDLE, + conn.info.transaction_status) + cur.execute("select count(*) from isolevel;") + self.assertEqual(1, cur.fetchone()[0]) + self.assertEqual(conn.isolation_level, + psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED) def test_isolation_level_autocommit(self): cnn1 = self.connect() cnn2 = self.connect() cnn2.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT) - cur1 = cnn1.cursor() - cur1.execute("select count(*) from isolevel;") - self.assertEqual(0, cur1.fetchone()[0]) - cnn1.commit() + with cnn1.cursor() as cur1: + cur1.execute("select count(*) from isolevel;") + self.assertEqual(0, cur1.fetchone()[0]) + cnn1.commit() - cur2 = cnn2.cursor() - cur2.execute("insert into isolevel values (10);") + with cnn2.cursor() as cur2: + cur2.execute("insert into isolevel values (10);") - cur1.execute("select count(*) from isolevel;") - self.assertEqual(1, cur1.fetchone()[0]) + cur1.execute("select count(*) from isolevel;") + self.assertEqual(1, cur1.fetchone()[0]) def test_isolation_level_read_committed(self): cnn1 = self.connect() cnn2 = self.connect() cnn2.set_isolation_level(ext.ISOLATION_LEVEL_READ_COMMITTED) - cur1 = cnn1.cursor() - cur1.execute("select count(*) from isolevel;") - self.assertEqual(0, cur1.fetchone()[0]) - cnn1.commit() + with cnn1.cursor() as cur1: + cur1.execute("select count(*) from isolevel;") + self.assertEqual(0, cur1.fetchone()[0]) + cnn1.commit() - cur2 = cnn2.cursor() - cur2.execute("insert into isolevel values (10);") - cur1.execute("insert into isolevel values (20);") + with cnn2.cursor() as cur2: + cur2.execute("insert into isolevel values (10);") + cur1.execute("insert into isolevel values (20);") - cur2.execute("select count(*) from isolevel;") - self.assertEqual(1, cur2.fetchone()[0]) - cnn1.commit() - cur2.execute("select count(*) from isolevel;") - self.assertEqual(2, cur2.fetchone()[0]) + cur2.execute("select count(*) from isolevel;") + self.assertEqual(1, cur2.fetchone()[0]) + cnn1.commit() + cur2.execute("select count(*) from isolevel;") + self.assertEqual(2, cur2.fetchone()[0]) - cur1.execute("select count(*) from isolevel;") - self.assertEqual(1, cur1.fetchone()[0]) - cnn2.commit() - cur1.execute("select count(*) from isolevel;") - self.assertEqual(2, cur1.fetchone()[0]) + cur1.execute("select count(*) from isolevel;") + self.assertEqual(1, cur1.fetchone()[0]) + cnn2.commit() + cur1.execute("select count(*) from isolevel;") + self.assertEqual(2, cur1.fetchone()[0]) def test_isolation_level_serializable(self): cnn1 = self.connect() cnn2 = self.connect() cnn2.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE) - cur1 = cnn1.cursor() - cur1.execute("select count(*) from isolevel;") - self.assertEqual(0, cur1.fetchone()[0]) - cnn1.commit() + with cnn1.cursor() as cur1: + cur1.execute("select count(*) from isolevel;") + self.assertEqual(0, cur1.fetchone()[0]) + cnn1.commit() - cur2 = cnn2.cursor() - cur2.execute("insert into isolevel values (10);") - cur1.execute("insert into isolevel values (20);") + with cnn2.cursor() as cur2: + cur2.execute("insert into isolevel values (10);") + cur1.execute("insert into isolevel values (20);") - cur2.execute("select count(*) from isolevel;") - self.assertEqual(1, cur2.fetchone()[0]) - cnn1.commit() - cur2.execute("select count(*) from isolevel;") - self.assertEqual(1, cur2.fetchone()[0]) + cur2.execute("select count(*) from isolevel;") + self.assertEqual(1, cur2.fetchone()[0]) + cnn1.commit() + cur2.execute("select count(*) from isolevel;") + self.assertEqual(1, cur2.fetchone()[0]) - cur1.execute("select count(*) from isolevel;") - self.assertEqual(1, cur1.fetchone()[0]) - cnn2.commit() - cur1.execute("select count(*) from isolevel;") - self.assertEqual(2, cur1.fetchone()[0]) + cur1.execute("select count(*) from isolevel;") + self.assertEqual(1, cur1.fetchone()[0]) + cnn2.commit() + cur1.execute("select count(*) from isolevel;") + self.assertEqual(2, cur1.fetchone()[0]) - cur2.execute("select count(*) from isolevel;") - self.assertEqual(2, cur2.fetchone()[0]) + cur2.execute("select count(*) from isolevel;") + self.assertEqual(2, cur2.fetchone()[0]) def test_isolation_level_closed(self): cnn = self.connect() @@ -766,99 +762,99 @@ class IsolationLevelsTestCase(ConnectingTestCase): cnn.set_isolation_level, 1) def test_setattr_isolation_level_int(self): - cur = self.conn.cursor() - self.conn.isolation_level = ext.ISOLATION_LEVEL_SERIALIZABLE - self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) + with self.conn.cursor() as cur: + self.conn.isolation_level = ext.ISOLATION_LEVEL_SERIALIZABLE + self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() - - self.conn.isolation_level = ext.ISOLATION_LEVEL_REPEATABLE_READ - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_REPEATABLE_READ) - self.assertEqual(cur.fetchone()[0], 'repeatable read') - else: - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_SERIALIZABLE) + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() + self.conn.rollback() - self.conn.isolation_level = ext.ISOLATION_LEVEL_READ_COMMITTED - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_READ_COMMITTED) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.isolation_level = ext.ISOLATION_LEVEL_REPEATABLE_READ + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_REPEATABLE_READ) + self.assertEqual(cur.fetchone()[0], 'repeatable read') + else: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_SERIALIZABLE) + self.assertEqual(cur.fetchone()[0], 'serializable') + self.conn.rollback() - self.conn.isolation_level = ext.ISOLATION_LEVEL_READ_UNCOMMITTED - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_READ_UNCOMMITTED) - self.assertEqual(cur.fetchone()[0], 'read uncommitted') - else: + self.conn.isolation_level = ext.ISOLATION_LEVEL_READ_COMMITTED self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_READ_COMMITTED) + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.rollback() - self.assertEqual(ext.ISOLATION_LEVEL_DEFAULT, None) - self.conn.isolation_level = ext.ISOLATION_LEVEL_DEFAULT - self.assertEqual(self.conn.isolation_level, None) - cur.execute("SHOW transaction_isolation;") - isol = cur.fetchone()[0] - cur.execute("SHOW default_transaction_isolation;") - self.assertEqual(cur.fetchone()[0], isol) + self.conn.isolation_level = ext.ISOLATION_LEVEL_READ_UNCOMMITTED + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_READ_UNCOMMITTED) + self.assertEqual(cur.fetchone()[0], 'read uncommitted') + else: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_READ_COMMITTED) + self.assertEqual(cur.fetchone()[0], 'read committed') + self.conn.rollback() + + self.assertEqual(ext.ISOLATION_LEVEL_DEFAULT, None) + self.conn.isolation_level = ext.ISOLATION_LEVEL_DEFAULT + self.assertEqual(self.conn.isolation_level, None) + cur.execute("SHOW transaction_isolation;") + isol = cur.fetchone()[0] + cur.execute("SHOW default_transaction_isolation;") + self.assertEqual(cur.fetchone()[0], isol) def test_setattr_isolation_level_str(self): - cur = self.conn.cursor() - self.conn.isolation_level = "serializable" - self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) + with self.conn.cursor() as cur: + self.conn.isolation_level = "serializable" + self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() - - self.conn.isolation_level = "repeatable read" - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_REPEATABLE_READ) - self.assertEqual(cur.fetchone()[0], 'repeatable read') - else: - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_SERIALIZABLE) + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() + self.conn.rollback() - self.conn.isolation_level = "read committed" - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_READ_COMMITTED) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.isolation_level = "repeatable read" + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_REPEATABLE_READ) + self.assertEqual(cur.fetchone()[0], 'repeatable read') + else: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_SERIALIZABLE) + self.assertEqual(cur.fetchone()[0], 'serializable') + self.conn.rollback() - self.conn.isolation_level = "read uncommitted" - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(self.conn.isolation_level, - ext.ISOLATION_LEVEL_READ_UNCOMMITTED) - self.assertEqual(cur.fetchone()[0], 'read uncommitted') - else: + self.conn.isolation_level = "read committed" self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_READ_COMMITTED) + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.rollback() - self.conn.isolation_level = "default" - self.assertEqual(self.conn.isolation_level, None) - cur.execute("SHOW transaction_isolation;") - isol = cur.fetchone()[0] - cur.execute("SHOW default_transaction_isolation;") - self.assertEqual(cur.fetchone()[0], isol) + self.conn.isolation_level = "read uncommitted" + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_READ_UNCOMMITTED) + self.assertEqual(cur.fetchone()[0], 'read uncommitted') + else: + self.assertEqual(self.conn.isolation_level, + ext.ISOLATION_LEVEL_READ_COMMITTED) + self.assertEqual(cur.fetchone()[0], 'read committed') + self.conn.rollback() + + self.conn.isolation_level = "default" + self.assertEqual(self.conn.isolation_level, None) + cur.execute("SHOW transaction_isolation;") + isol = cur.fetchone()[0] + cur.execute("SHOW default_transaction_isolation;") + self.assertEqual(cur.fetchone()[0], isol) def test_setattr_isolation_level_invalid(self): self.assertRaises(ValueError, setattr, self.conn, 'isolation_level', 0) @@ -946,20 +942,20 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_begin(xid) self.assertEqual(cnn.status, ext.STATUS_BEGIN) - cur = cnn.cursor() - cur.execute("insert into test_tpc values ('test_tpc_commit');") - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + with cnn.cursor() as cur: + cur.execute("insert into test_tpc values ('test_tpc_commit');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_prepare() - self.assertEqual(cnn.status, ext.STATUS_PREPARED) - self.assertEqual(1, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + cnn.tpc_prepare() + self.assertEqual(cnn.status, ext.STATUS_PREPARED) + self.assertEqual(1, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_commit() - self.assertEqual(cnn.status, ext.STATUS_READY) - self.assertEqual(0, self.count_xacts()) - self.assertEqual(1, self.count_test_records()) + cnn.tpc_commit() + self.assertEqual(cnn.status, ext.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(1, self.count_test_records()) def test_tpc_commit_one_phase(self): cnn = self.connect() @@ -969,15 +965,15 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_begin(xid) self.assertEqual(cnn.status, ext.STATUS_BEGIN) - cur = cnn.cursor() - cur.execute("insert into test_tpc values ('test_tpc_commit_1p');") - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + with cnn.cursor() as cur: + cur.execute("insert into test_tpc values ('test_tpc_commit_1p');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_commit() - self.assertEqual(cnn.status, ext.STATUS_READY) - self.assertEqual(0, self.count_xacts()) - self.assertEqual(1, self.count_test_records()) + cnn.tpc_commit() + self.assertEqual(cnn.status, ext.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(1, self.count_test_records()) def test_tpc_commit_recovered(self): cnn = self.connect() @@ -1013,20 +1009,20 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_begin(xid) self.assertEqual(cnn.status, ext.STATUS_BEGIN) - cur = cnn.cursor() - cur.execute("insert into test_tpc values ('test_tpc_rollback');") - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + with cnn.cursor() as cur: + cur.execute("insert into test_tpc values ('test_tpc_rollback');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_prepare() - self.assertEqual(cnn.status, ext.STATUS_PREPARED) - self.assertEqual(1, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + cnn.tpc_prepare() + self.assertEqual(cnn.status, ext.STATUS_PREPARED) + self.assertEqual(1, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_rollback() - self.assertEqual(cnn.status, ext.STATUS_READY) - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + cnn.tpc_rollback() + self.assertEqual(cnn.status, ext.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) def test_tpc_rollback_one_phase(self): cnn = self.connect() @@ -1036,15 +1032,15 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_begin(xid) self.assertEqual(cnn.status, ext.STATUS_BEGIN) - cur = cnn.cursor() - cur.execute("insert into test_tpc values ('test_tpc_rollback_1p');") - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + with cnn.cursor() as cur: + cur.execute("insert into test_tpc values ('test_tpc_rollback_1p');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_rollback() - self.assertEqual(cnn.status, ext.STATUS_READY) - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + cnn.tpc_rollback() + self.assertEqual(cnn.status, ext.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) def test_tpc_rollback_recovered(self): cnn = self.connect() @@ -1054,23 +1050,23 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_begin(xid) self.assertEqual(cnn.status, ext.STATUS_BEGIN) - cur = cnn.cursor() - cur.execute("insert into test_tpc values ('test_tpc_commit_rec');") - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + with cnn.cursor() as cur: + cur.execute("insert into test_tpc values ('test_tpc_commit_rec');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn.tpc_prepare() - cnn.close() - self.assertEqual(1, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + cnn.tpc_prepare() + cnn.close() + self.assertEqual(1, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) - cnn = self.connect() - xid = cnn.xid(1, "gtrid", "bqual") - cnn.tpc_rollback(xid) + cnn = self.connect() + xid = cnn.xid(1, "gtrid", "bqual") + cnn.tpc_rollback(xid) - self.assertEqual(cnn.status, ext.STATUS_READY) - self.assertEqual(0, self.count_xacts()) - self.assertEqual(0, self.count_test_records()) + self.assertEqual(cnn.status, ext.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) def test_status_after_recover(self): cnn = self.connect() @@ -1078,28 +1074,28 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_recover() self.assertEqual(ext.STATUS_READY, cnn.status) - cur = cnn.cursor() - cur.execute("select 1") - self.assertEqual(ext.STATUS_BEGIN, cnn.status) - cnn.tpc_recover() - self.assertEqual(ext.STATUS_BEGIN, cnn.status) + with cnn.cursor() as cur: + cur.execute("select 1") + self.assertEqual(ext.STATUS_BEGIN, cnn.status) + cnn.tpc_recover() + self.assertEqual(ext.STATUS_BEGIN, cnn.status) def test_recovered_xids(self): # insert a few test xns cnn = self.connect() cnn.set_isolation_level(0) - cur = cnn.cursor() - cur.execute("begin; prepare transaction '1-foo';") - cur.execute("begin; prepare transaction '2-bar';") + with cnn.cursor() as cur: + cur.execute("begin; prepare transaction '1-foo';") + cur.execute("begin; prepare transaction '2-bar';") - # read the values to return - cur.execute(""" - select gid, prepared, owner, database - from pg_prepared_xacts - where database = %s;""", - (dbname,)) - okvals = cur.fetchall() - okvals.sort() + # read the values to return + cur.execute(""" + select gid, prepared, owner, database + from pg_prepared_xacts + where database = %s;""", + (dbname,)) + okvals = cur.fetchall() + okvals.sort() cnn = self.connect() xids = cnn.tpc_recover() @@ -1121,10 +1117,10 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_prepare() cnn = self.connect() - cur = cnn.cursor() - cur.execute("select gid from pg_prepared_xacts where database = %s;", - (dbname,)) - self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0]) + with cnn.cursor() as cur: + cur.execute("select gid from pg_prepared_xacts where database = %s;", + (dbname,)) + self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0]) @slow def test_xid_roundtrip(self): @@ -1248,71 +1244,71 @@ class TransactionControlTests(ConnectingTestCase): ext.ISOLATION_LEVEL_SERIALIZABLE) def test_not_in_transaction(self): - cur = self.conn.cursor() - cur.execute("select 1") - self.assertRaises(psycopg2.ProgrammingError, - self.conn.set_session, - ext.ISOLATION_LEVEL_SERIALIZABLE) + with self.conn.cursor() as cur: + cur.execute("select 1") + self.assertRaises(psycopg2.ProgrammingError, + self.conn.set_session, + ext.ISOLATION_LEVEL_SERIALIZABLE) def test_set_isolation_level(self): - cur = self.conn.cursor() - self.conn.set_session( - ext.ISOLATION_LEVEL_SERIALIZABLE) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() - - self.conn.set_session( - ext.ISOLATION_LEVEL_REPEATABLE_READ) - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(cur.fetchone()[0], 'repeatable read') - else: + with self.conn.cursor() as cur: + self.conn.set_session( + ext.ISOLATION_LEVEL_SERIALIZABLE) + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() + self.conn.rollback() - self.conn.set_session( - isolation_level=ext.ISOLATION_LEVEL_READ_COMMITTED) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.set_session( + ext.ISOLATION_LEVEL_REPEATABLE_READ) + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(cur.fetchone()[0], 'repeatable read') + else: + self.assertEqual(cur.fetchone()[0], 'serializable') + self.conn.rollback() - self.conn.set_session( - isolation_level=ext.ISOLATION_LEVEL_READ_UNCOMMITTED) - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(cur.fetchone()[0], 'read uncommitted') - else: + self.conn.set_session( + isolation_level=ext.ISOLATION_LEVEL_READ_COMMITTED) + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.rollback() + + self.conn.set_session( + isolation_level=ext.ISOLATION_LEVEL_READ_UNCOMMITTED) + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(cur.fetchone()[0], 'read uncommitted') + else: + self.assertEqual(cur.fetchone()[0], 'read committed') + self.conn.rollback() def test_set_isolation_level_str(self): - cur = self.conn.cursor() - self.conn.set_session("serializable") - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() - - self.conn.set_session("repeatable read") - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(cur.fetchone()[0], 'repeatable read') - else: + with self.conn.cursor() as cur: + self.conn.set_session("serializable") + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'serializable') - self.conn.rollback() + self.conn.rollback() - self.conn.set_session("read committed") - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.set_session("repeatable read") + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(cur.fetchone()[0], 'repeatable read') + else: + self.assertEqual(cur.fetchone()[0], 'serializable') + self.conn.rollback() - self.conn.set_session("read uncommitted") - cur.execute("SHOW transaction_isolation;") - if self.conn.info.server_version > 80000: - self.assertEqual(cur.fetchone()[0], 'read uncommitted') - else: + self.conn.set_session("read committed") + cur.execute("SHOW transaction_isolation;") self.assertEqual(cur.fetchone()[0], 'read committed') - self.conn.rollback() + self.conn.rollback() + + self.conn.set_session("read uncommitted") + cur.execute("SHOW transaction_isolation;") + if self.conn.info.server_version > 80000: + self.assertEqual(cur.fetchone()[0], 'read uncommitted') + else: + self.assertEqual(cur.fetchone()[0], 'read committed') + self.conn.rollback() def test_bad_isolation_level(self): self.assertRaises(ValueError, self.conn.set_session, 0) @@ -1322,87 +1318,87 @@ class TransactionControlTests(ConnectingTestCase): def test_set_read_only(self): self.assert_(self.conn.readonly is None) - cur = self.conn.cursor() - self.conn.set_session(readonly=True) - self.assert_(self.conn.readonly is True) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') - self.conn.rollback() - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') - self.conn.rollback() + with self.conn.cursor() as cur: + self.conn.set_session(readonly=True) + self.assert_(self.conn.readonly is True) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') + self.conn.rollback() + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') + self.conn.rollback() - self.conn.set_session(readonly=False) - self.assert_(self.conn.readonly is False) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'off') - self.conn.rollback() + self.conn.set_session(readonly=False) + self.assert_(self.conn.readonly is False) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'off') + self.conn.rollback() def test_setattr_read_only(self): - cur = self.conn.cursor() - self.conn.readonly = True - self.assert_(self.conn.readonly is True) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') - self.assertRaises(self.conn.ProgrammingError, - setattr, self.conn, 'readonly', False) - self.assert_(self.conn.readonly is True) - self.conn.rollback() - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') - self.conn.rollback() + with self.conn.cursor() as cur: + self.conn.readonly = True + self.assert_(self.conn.readonly is True) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') + self.assertRaises(self.conn.ProgrammingError, + setattr, self.conn, 'readonly', False) + self.assert_(self.conn.readonly is True) + self.conn.rollback() + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') + self.conn.rollback() - cur = self.conn.cursor() - self.conn.readonly = None - self.assert_(self.conn.readonly is None) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'off') # assume defined by server - self.conn.rollback() + with self.conn.cursor() as cur: + self.conn.readonly = None + self.assert_(self.conn.readonly is None) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'off') # assume defined by server + self.conn.rollback() - self.conn.readonly = False - self.assert_(self.conn.readonly is False) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'off') - self.conn.rollback() + self.conn.readonly = False + self.assert_(self.conn.readonly is False) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'off') + self.conn.rollback() def test_set_default(self): - cur = self.conn.cursor() - cur.execute("SHOW transaction_isolation;") - isolevel = cur.fetchone()[0] - cur.execute("SHOW transaction_read_only;") - readonly = cur.fetchone()[0] - self.conn.rollback() + with self.conn.cursor() as cur: + cur.execute("SHOW transaction_isolation;") + isolevel = cur.fetchone()[0] + cur.execute("SHOW transaction_read_only;") + readonly = cur.fetchone()[0] + self.conn.rollback() - self.conn.set_session(isolation_level='serializable', readonly=True) - self.conn.set_session(isolation_level='default', readonly='default') + self.conn.set_session(isolation_level='serializable', readonly=True) + self.conn.set_session(isolation_level='default', readonly='default') - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], isolevel) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], readonly) + cur.execute("SHOW transaction_isolation;") + self.assertEqual(cur.fetchone()[0], isolevel) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], readonly) @skip_before_postgres(9, 1) def test_set_deferrable(self): self.assert_(self.conn.deferrable is None) - cur = self.conn.cursor() - self.conn.set_session(readonly=True, deferrable=True) - self.assert_(self.conn.deferrable is True) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'on') - self.conn.rollback() - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'on') - self.conn.rollback() + with self.conn.cursor() as cur: + self.conn.set_session(readonly=True, deferrable=True) + self.assert_(self.conn.deferrable is True) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'on') + self.conn.rollback() + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'on') + self.conn.rollback() - self.conn.set_session(deferrable=False) - self.assert_(self.conn.deferrable is False) - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'off') - self.conn.rollback() + self.conn.set_session(deferrable=False) + self.assert_(self.conn.deferrable is False) + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'off') + self.conn.rollback() @skip_after_postgres(9, 1) def test_set_deferrable_error(self): @@ -1413,49 +1409,49 @@ class TransactionControlTests(ConnectingTestCase): @skip_before_postgres(9, 1) def test_setattr_deferrable(self): - cur = self.conn.cursor() - self.conn.deferrable = True - self.assert_(self.conn.deferrable is True) - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'on') - self.assertRaises(self.conn.ProgrammingError, - setattr, self.conn, 'deferrable', False) - self.assert_(self.conn.deferrable is True) - self.conn.rollback() - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'on') - self.conn.rollback() + with self.conn.cursor() as cur: + self.conn.deferrable = True + self.assert_(self.conn.deferrable is True) + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'on') + self.assertRaises(self.conn.ProgrammingError, + setattr, self.conn, 'deferrable', False) + self.assert_(self.conn.deferrable is True) + self.conn.rollback() + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'on') + self.conn.rollback() - cur = self.conn.cursor() - self.conn.deferrable = None - self.assert_(self.conn.deferrable is None) - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'off') # assume defined by server - self.conn.rollback() + with self.conn.cursor() as cur: + self.conn.deferrable = None + self.assert_(self.conn.deferrable is None) + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'off') # assume defined by server + self.conn.rollback() - self.conn.deferrable = False - self.assert_(self.conn.deferrable is False) - cur.execute("SHOW transaction_deferrable;") - self.assertEqual(cur.fetchone()[0], 'off') - self.conn.rollback() + self.conn.deferrable = False + self.assert_(self.conn.deferrable is False) + cur.execute("SHOW transaction_deferrable;") + self.assertEqual(cur.fetchone()[0], 'off') + self.conn.rollback() def test_mixing_session_attribs(self): - cur = self.conn.cursor() - self.conn.autocommit = True - self.conn.readonly = True + with self.conn.cursor() as cur: + self.conn.autocommit = True + self.conn.readonly = True - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') - cur.execute("SHOW default_transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') + cur.execute("SHOW default_transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') - self.conn.autocommit = False - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') + self.conn.autocommit = False + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') - cur.execute("SHOW default_transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'off') + cur.execute("SHOW default_transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'off') def test_idempotence_check(self): self.conn.autocommit = False @@ -1463,9 +1459,9 @@ class TransactionControlTests(ConnectingTestCase): self.conn.autocommit = True self.conn.readonly = True - cur = self.conn.cursor() - cur.execute("SHOW transaction_read_only") - self.assertEqual(cur.fetchone()[0], 'on') + with self.conn.cursor() as cur: + cur.execute("SHOW transaction_read_only") + self.assertEqual(cur.fetchone()[0], 'on') class TestEncryptPassword(ConnectingTestCase): @@ -1486,26 +1482,26 @@ class TestEncryptPassword(ConnectingTestCase): @skip_before_libpq(10) @skip_before_postgres(10) def test_encrypt_server(self): - cur = self.conn.cursor() - cur.execute("SHOW password_encryption;") - server_encryption_algorithm = cur.fetchone()[0] + with self.conn.cursor() as cur: + cur.execute("SHOW password_encryption;") + server_encryption_algorithm = cur.fetchone()[0] - enc_password = ext.encrypt_password( - 'psycopg2', 'ashesh', self.conn) + enc_password = ext.encrypt_password( + 'psycopg2', 'ashesh', self.conn) + + if server_encryption_algorithm == 'md5': + self.assertEqual( + enc_password, 'md594839d658c28a357126f105b9cb14cfc') + elif server_encryption_algorithm == 'scram-sha-256': + self.assertEqual(enc_password[:14], 'SCRAM-SHA-256$') - if server_encryption_algorithm == 'md5': self.assertEqual( - enc_password, 'md594839d658c28a357126f105b9cb14cfc') - elif server_encryption_algorithm == 'scram-sha-256': - self.assertEqual(enc_password[:14], 'SCRAM-SHA-256$') + ext.encrypt_password( + 'psycopg2', 'ashesh', self.conn, 'scram-sha-256' + )[:14], 'SCRAM-SHA-256$') - self.assertEqual( - ext.encrypt_password( - 'psycopg2', 'ashesh', self.conn, 'scram-sha-256' - )[:14], 'SCRAM-SHA-256$') - - self.assertRaises(psycopg2.ProgrammingError, - ext.encrypt_password, 'psycopg2', 'ashesh', self.conn, 'abc') + self.assertRaises(psycopg2.ProgrammingError, + ext.encrypt_password, 'psycopg2', 'ashesh', self.conn, 'abc') def test_encrypt_md5(self): self.assertEqual( @@ -1568,16 +1564,16 @@ class AutocommitTests(ConnectingTestCase): self.assertEqual(self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE) - cur = self.conn.cursor() - cur.execute('select 1;') - self.assertEqual(self.conn.status, ext.STATUS_BEGIN) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_INTRANS) + with self.conn.cursor() as cur: + cur.execute('select 1;') + self.assertEqual(self.conn.status, ext.STATUS_BEGIN) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_INTRANS) - self.conn.rollback() - self.assertEqual(self.conn.status, ext.STATUS_READY) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_IDLE) + self.conn.rollback() + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_IDLE) def test_set_autocommit(self): self.conn.autocommit = True @@ -1586,28 +1582,28 @@ class AutocommitTests(ConnectingTestCase): self.assertEqual(self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE) - cur = self.conn.cursor() - cur.execute('select 1;') - self.assertEqual(self.conn.status, ext.STATUS_READY) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_IDLE) + with self.conn.cursor() as cur: + cur.execute('select 1;') + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_IDLE) - self.conn.autocommit = False - self.assert_(not self.conn.autocommit) - self.assertEqual(self.conn.status, ext.STATUS_READY) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_IDLE) + self.conn.autocommit = False + self.assert_(not self.conn.autocommit) + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_IDLE) - cur.execute('select 1;') - self.assertEqual(self.conn.status, ext.STATUS_BEGIN) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_INTRANS) + cur.execute('select 1;') + self.assertEqual(self.conn.status, ext.STATUS_BEGIN) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_INTRANS) def test_set_intrans_error(self): - cur = self.conn.cursor() - cur.execute('select 1;') - self.assertRaises(psycopg2.ProgrammingError, - setattr, self.conn, 'autocommit', True) + with self.conn.cursor() as cur: + cur.execute('select 1;') + self.assertRaises(psycopg2.ProgrammingError, + setattr, self.conn, 'autocommit', True) def test_set_session_autocommit(self): self.conn.set_session(autocommit=True) @@ -1616,34 +1612,34 @@ class AutocommitTests(ConnectingTestCase): self.assertEqual(self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE) - cur = self.conn.cursor() - cur.execute('select 1;') - self.assertEqual(self.conn.status, ext.STATUS_READY) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_IDLE) + with self.conn.cursor() as cur: + cur.execute('select 1;') + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_IDLE) - self.conn.set_session(autocommit=False) - self.assert_(not self.conn.autocommit) - self.assertEqual(self.conn.status, ext.STATUS_READY) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_IDLE) + self.conn.set_session(autocommit=False) + self.assert_(not self.conn.autocommit) + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_IDLE) - cur.execute('select 1;') - self.assertEqual(self.conn.status, ext.STATUS_BEGIN) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_INTRANS) - self.conn.rollback() + cur.execute('select 1;') + self.assertEqual(self.conn.status, ext.STATUS_BEGIN) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_INTRANS) + self.conn.rollback() - self.conn.set_session('serializable', readonly=True, autocommit=True) - self.assert_(self.conn.autocommit) - cur.execute('select 1;') - self.assertEqual(self.conn.status, ext.STATUS_READY) - self.assertEqual(self.conn.info.transaction_status, - ext.TRANSACTION_STATUS_IDLE) - cur.execute("SHOW transaction_isolation;") - self.assertEqual(cur.fetchone()[0], 'serializable') - cur.execute("SHOW transaction_read_only;") - self.assertEqual(cur.fetchone()[0], 'on') + self.conn.set_session('serializable', readonly=True, autocommit=True) + self.assert_(self.conn.autocommit) + cur.execute('select 1;') + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assertEqual(self.conn.info.transaction_status, + ext.TRANSACTION_STATUS_IDLE) + cur.execute("SHOW transaction_isolation;") + self.assertEqual(cur.fetchone()[0], 'serializable') + cur.execute("SHOW transaction_read_only;") + self.assertEqual(cur.fetchone()[0], 'on') class PasswordLeakTestCase(ConnectingTestCase): @@ -1762,10 +1758,10 @@ class TestConnectionInfo(ConnectingTestCase): self.assert_(self.bconn.info.dbname is None) def test_user(self): - cur = self.conn.cursor() - cur.execute("select user") - self.assertEqual(self.conn.info.user, cur.fetchone()[0]) - self.assert_(self.bconn.info.user is None) + with self.conn.cursor() as cur: + cur.execute("select user") + self.assertEqual(self.conn.info.user, cur.fetchone()[0]) + self.assert_(self.bconn.info.user is None) def test_password(self): self.assert_(isinstance(self.conn.info.password, str)) @@ -1801,69 +1797,69 @@ class TestConnectionInfo(ConnectingTestCase): def test_transaction_status(self): self.assertEqual(self.conn.info.transaction_status, 0) - cur = self.conn.cursor() - cur.execute("select 1") - self.assertEqual(self.conn.info.transaction_status, 2) - self.assertEqual(self.bconn.info.transaction_status, 4) + with self.conn.cursor() as cur: + cur.execute("select 1") + self.assertEqual(self.conn.info.transaction_status, 2) + self.assertEqual(self.bconn.info.transaction_status, 4) def test_parameter_status(self): - cur = self.conn.cursor() - try: - cur.execute("show server_version") - except psycopg2.DatabaseError: - self.assertIsInstance( - self.conn.info.parameter_status('server_version'), str) - else: - self.assertEqual( - self.conn.info.parameter_status('server_version'), - cur.fetchone()[0]) + with self.conn.cursor() as cur: + try: + cur.execute("show server_version") + except psycopg2.DatabaseError: + self.assertIsInstance( + self.conn.info.parameter_status('server_version'), str) + else: + self.assertEqual( + self.conn.info.parameter_status('server_version'), + cur.fetchone()[0]) - self.assertIsNone(self.conn.info.parameter_status('wat')) - self.assertIsNone(self.bconn.info.parameter_status('server_version')) + self.assertIsNone(self.conn.info.parameter_status('wat')) + self.assertIsNone(self.bconn.info.parameter_status('server_version')) def test_protocol_version(self): self.assertEqual(self.conn.info.protocol_version, 3) self.assertEqual(self.bconn.info.protocol_version, 0) def test_server_version(self): - cur = self.conn.cursor() - try: - cur.execute("show server_version_num") - except psycopg2.DatabaseError: - self.assert_(isinstance(self.conn.info.server_version, int)) - else: - self.assertEqual( - self.conn.info.server_version, int(cur.fetchone()[0])) + with self.conn.cursor() as cur: + try: + cur.execute("show server_version_num") + except psycopg2.DatabaseError: + self.assert_(isinstance(self.conn.info.server_version, int)) + else: + self.assertEqual( + self.conn.info.server_version, int(cur.fetchone()[0])) - self.assertEqual(self.bconn.info.server_version, 0) + self.assertEqual(self.bconn.info.server_version, 0) def test_error_message(self): self.assertIsNone(self.conn.info.error_message) self.assertIsNotNone(self.bconn.info.error_message) - cur = self.conn.cursor() - try: - cur.execute("select 1 from nosuchtable") - except psycopg2.DatabaseError: - pass + with self.conn.cursor() as cur: + try: + cur.execute("select 1 from nosuchtable") + except psycopg2.DatabaseError: + pass - self.assert_('nosuchtable' in self.conn.info.error_message) + self.assert_('nosuchtable' in self.conn.info.error_message) def test_socket(self): self.assert_(self.conn.info.socket >= 0) self.assert_(self.bconn.info.socket < 0) def test_backend_pid(self): - cur = self.conn.cursor() - try: - cur.execute("select pg_backend_pid()") - except psycopg2.DatabaseError: - self.assert_(self.conn.info.backend_pid > 0) - else: - self.assertEqual( - self.conn.info.backend_pid, int(cur.fetchone()[0])) + with self.conn.cursor() as cur: + try: + cur.execute("select pg_backend_pid()") + except psycopg2.DatabaseError: + self.assert_(self.conn.info.backend_pid > 0) + else: + self.assertEqual( + self.conn.info.backend_pid, int(cur.fetchone()[0])) - self.assert_(self.bconn.info.backend_pid == 0) + self.assert_(self.bconn.info.backend_pid == 0) def test_needs_password(self): self.assertIs(self.conn.info.needs_password, False) diff --git a/tests/test_copy.py b/tests/test_copy.py index 4fdf1641..37a2416b 100755 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -66,12 +66,12 @@ class CopyTests(ConnectingTestCase): self._create_temp_table() def _create_temp_table(self): - curs = self.conn.cursor() - curs.execute(''' - CREATE TEMPORARY TABLE tcopy ( - id serial PRIMARY KEY, - data text - )''') + with self.conn.cursor() as curs: + curs.execute(''' + CREATE TEMPORARY TABLE tcopy ( + id serial PRIMARY KEY, + data text + )''') @slow def test_copy_from(self): @@ -92,31 +92,31 @@ class CopyTests(ConnectingTestCase): curs.close() def test_copy_from_cols(self): - curs = self.conn.cursor() - f = StringIO() - for i in range(10): - f.write("%s\n" % (i,)) + with self.conn.cursor() as curs: + f = StringIO() + for i in range(10): + f.write("%s\n" % (i,)) - f.seek(0) - curs.copy_from(MinimalRead(f), "tcopy", columns=['id']) + f.seek(0) + curs.copy_from(MinimalRead(f), "tcopy", columns=['id']) - curs.execute("select * from tcopy order by id") - self.assertEqual([(i, None) for i in range(10)], curs.fetchall()) + curs.execute("select * from tcopy order by id") + self.assertEqual([(i, None) for i in range(10)], curs.fetchall()) def test_copy_from_cols_err(self): - curs = self.conn.cursor() - f = StringIO() - for i in range(10): - f.write("%s\n" % (i,)) + with self.conn.cursor() as curs: + f = StringIO() + for i in range(10): + f.write("%s\n" % (i,)) - f.seek(0) + f.seek(0) - def cols(): - raise ZeroDivisionError() - yield 'id' + def cols(): + raise ZeroDivisionError() + yield 'id' - self.assertRaises(ZeroDivisionError, - curs.copy_from, MinimalRead(f), "tcopy", columns=cols()) + self.assertRaises(ZeroDivisionError, + curs.copy_from, MinimalRead(f), "tcopy", columns=cols()) @slow def test_copy_to(self): @@ -140,14 +140,14 @@ class CopyTests(ConnectingTestCase): + list(range(160, 256))).decode('latin1') about = abin.replace('\\', '\\\\') - curs = self.conn.cursor() - curs.execute('insert into tcopy values (%s, %s)', - (42, abin)) + with self.conn.cursor() as curs: + curs.execute('insert into tcopy values (%s, %s)', + (42, abin)) - f = io.StringIO() - curs.copy_to(f, 'tcopy', columns=('data',)) - f.seek(0) - self.assertEqual(f.readline().rstrip(), about) + f = io.StringIO() + curs.copy_to(f, 'tcopy', columns=('data',)) + f.seek(0) + self.assertEqual(f.readline().rstrip(), about) def test_copy_bytes(self): self.conn.set_client_encoding('latin1') @@ -161,14 +161,14 @@ class CopyTests(ConnectingTestCase): + list(range(160, 255))).decode('latin1') about = abin.replace('\\', '\\\\').encode('latin1') - curs = self.conn.cursor() - curs.execute('insert into tcopy values (%s, %s)', - (42, abin)) + with self.conn.cursor() as curs: + curs.execute('insert into tcopy values (%s, %s)', + (42, abin)) - f = io.BytesIO() - curs.copy_to(f, 'tcopy', columns=('data',)) - f.seek(0) - self.assertEqual(f.readline().rstrip(), about) + f = io.BytesIO() + curs.copy_to(f, 'tcopy', columns=('data',)) + f.seek(0) + self.assertEqual(f.readline().rstrip(), about) def test_copy_expert_textiobase(self): self.conn.set_client_encoding('latin1') @@ -188,35 +188,35 @@ class CopyTests(ConnectingTestCase): f.write(about) f.seek(0) - curs = self.conn.cursor() - psycopg2.extensions.register_type( - psycopg2.extensions.UNICODE, curs) + with self.conn.cursor() as curs: + psycopg2.extensions.register_type( + psycopg2.extensions.UNICODE, curs) - curs.copy_expert('COPY tcopy (data) FROM STDIN', f) - curs.execute("select data from tcopy;") - self.assertEqual(curs.fetchone()[0], abin) + curs.copy_expert('COPY tcopy (data) FROM STDIN', f) + curs.execute("select data from tcopy;") + self.assertEqual(curs.fetchone()[0], abin) - f = io.StringIO() - curs.copy_expert('COPY tcopy (data) TO STDOUT', f) - f.seek(0) - self.assertEqual(f.readline().rstrip(), about) + f = io.StringIO() + curs.copy_expert('COPY tcopy (data) TO STDOUT', f) + f.seek(0) + self.assertEqual(f.readline().rstrip(), about) - # same tests with setting size - f = io.StringIO() - f.write(about) - f.seek(0) - exp_size = 123 - # hack here to leave file as is, only check size when reading - real_read = f.read + # same tests with setting size + f = io.StringIO() + f.write(about) + f.seek(0) + exp_size = 123 + # hack here to leave file as is, only check size when reading + real_read = f.read - def read(_size, f=f, exp_size=exp_size): - self.assertEqual(_size, exp_size) - return real_read(_size) + def read(_size, f=f, exp_size=exp_size): + self.assertEqual(_size, exp_size) + return real_read(_size) - f.read = read - curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size) - curs.execute("select data from tcopy;") - self.assertEqual(curs.fetchone()[0], abin) + f.read = read + curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size) + curs.execute("select data from tcopy;") + self.assertEqual(curs.fetchone()[0], abin) def _copy_from(self, curs, nrecs, srec, copykw): f = StringIO() @@ -254,56 +254,54 @@ class CopyTests(ConnectingTestCase): pass f = Whatever() - curs = self.conn.cursor() - self.assertRaises(TypeError, - curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f) + with self.conn.cursor() as curs: + self.assertRaises(TypeError, + curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f) def test_copy_no_column_limit(self): cols = ["c%050d" % i for i in range(200)] - curs = self.conn.cursor() - curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join( - ["%s int" % c for c in cols])) - curs.execute("INSERT INTO manycols DEFAULT VALUES") + with self.conn.cursor() as curs: + curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join( + ["%s int" % c for c in cols])) + curs.execute("INSERT INTO manycols DEFAULT VALUES") - f = StringIO() - curs.copy_to(f, "manycols", columns=cols) - f.seek(0) - self.assertEqual(f.read().split(), ['\\N'] * len(cols)) + f = StringIO() + curs.copy_to(f, "manycols", columns=cols) + f.seek(0) + self.assertEqual(f.read().split(), ['\\N'] * len(cols)) - f.seek(0) - curs.copy_from(f, "manycols", columns=cols) - curs.execute("select count(*) from manycols;") - self.assertEqual(curs.fetchone()[0], 2) + f.seek(0) + curs.copy_from(f, "manycols", columns=cols) + curs.execute("select count(*) from manycols;") + self.assertEqual(curs.fetchone()[0], 2) @skip_before_postgres(8, 2) # they don't send the count def test_copy_rowcount(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data']) + self.assertEqual(curs.rowcount, 3) - curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data']) - self.assertEqual(curs.rowcount, 3) + curs.copy_expert( + "copy tcopy (data) from stdin", + StringIO('ddd\neee\n')) + self.assertEqual(curs.rowcount, 2) - curs.copy_expert( - "copy tcopy (data) from stdin", - StringIO('ddd\neee\n')) - self.assertEqual(curs.rowcount, 2) + curs.copy_to(StringIO(), "tcopy") + self.assertEqual(curs.rowcount, 5) - curs.copy_to(StringIO(), "tcopy") - self.assertEqual(curs.rowcount, 5) - - curs.execute("insert into tcopy (data) values ('fff')") - curs.copy_expert("copy tcopy to stdout", StringIO()) - self.assertEqual(curs.rowcount, 6) + curs.execute("insert into tcopy (data) values ('fff')") + curs.copy_expert("copy tcopy to stdout", StringIO()) + self.assertEqual(curs.rowcount, 6) def test_copy_rowcount_error(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + curs.execute("insert into tcopy (data) values ('fff')") + self.assertEqual(curs.rowcount, 1) - curs.execute("insert into tcopy (data) values ('fff')") - self.assertEqual(curs.rowcount, 1) - - self.assertRaises(psycopg2.DataError, - curs.copy_from, StringIO('aaa\nbbb\nccc\n'), 'tcopy') - self.assertEqual(curs.rowcount, -1) + self.assertRaises(psycopg2.DataError, + curs.copy_from, StringIO('aaa\nbbb\nccc\n'), 'tcopy') + self.assertEqual(curs.rowcount, -1) @slow def test_copy_from_segfault(self): @@ -317,6 +315,7 @@ try: curs.execute("copy copy_segf from stdin") except psycopg2.ProgrammingError: pass +curs.close() conn.close() """ % {'dsn': dsn}) @@ -336,6 +335,7 @@ try: curs.execute("copy copy_segf to stdout") except psycopg2.ProgrammingError: pass +curs.close() conn.close() """ % {'dsn': dsn}) @@ -351,24 +351,24 @@ conn.close() def readline(self): return 1 / 0 - curs = self.conn.cursor() - # It seems we cannot do this, but now at least we propagate the error - # self.assertRaises(ZeroDivisionError, - # curs.copy_from, BrokenRead(), "tcopy") - try: - curs.copy_from(BrokenRead(), "tcopy") - except Exception as e: - self.assert_('ZeroDivisionError' in str(e)) + with self.conn.cursor() as curs: + # It seems we cannot do this, but now at least we propagate the error + # self.assertRaises(ZeroDivisionError, + # curs.copy_from, BrokenRead(), "tcopy") + try: + curs.copy_from(BrokenRead(), "tcopy") + except Exception as e: + self.assert_('ZeroDivisionError' in str(e)) def test_copy_to_propagate_error(self): class BrokenWrite(TextIOBase): def write(self, data): return 1 / 0 - curs = self.conn.cursor() - curs.execute("insert into tcopy values (10, 'hi')") - self.assertRaises(ZeroDivisionError, - curs.copy_to, BrokenWrite(), "tcopy") + with self.conn.cursor() as curs: + curs.execute("insert into tcopy values (10, 'hi')") + self.assertRaises(ZeroDivisionError, + curs.copy_to, BrokenWrite(), "tcopy") def test_suite(): diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 4d180962..c1e42137 100755 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -51,10 +51,10 @@ class CursorTests(ConnectingTestCase): self.assert_(cur.closed) def test_empty_query(self): - cur = self.conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, cur.execute, "") - self.assertRaises(psycopg2.ProgrammingError, cur.execute, " ") - self.assertRaises(psycopg2.ProgrammingError, cur.execute, ";") + with self.conn.cursor() as cur: + self.assertRaises(psycopg2.ProgrammingError, cur.execute, "") + self.assertRaises(psycopg2.ProgrammingError, cur.execute, " ") + self.assertRaises(psycopg2.ProgrammingError, cur.execute, ";") def test_executemany_propagate_exceptions(self): conn = self.conn @@ -70,58 +70,57 @@ class CursorTests(ConnectingTestCase): def test_mogrify_unicode(self): conn = self.conn - cur = conn.cursor() + with conn.cursor() as cur: + # test consistency between execute and mogrify. - # test consistency between execute and mogrify. + # unicode query containing only ascii data + cur.execute(u"SELECT 'foo';") + self.assertEqual('foo', cur.fetchone()[0]) + self.assertEqual(b"SELECT 'foo';", cur.mogrify(u"SELECT 'foo';")) - # unicode query containing only ascii data - cur.execute(u"SELECT 'foo';") - self.assertEqual('foo', cur.fetchone()[0]) - self.assertEqual(b"SELECT 'foo';", cur.mogrify(u"SELECT 'foo';")) + conn.set_client_encoding('UTF8') + snowman = u"\u2603" - conn.set_client_encoding('UTF8') - snowman = u"\u2603" + def b(s): + if isinstance(s, text_type): + return s.encode('utf8') + else: + return s - def b(s): - if isinstance(s, text_type): - return s.encode('utf8') - else: - return s + # unicode query with non-ascii data + cur.execute(u"SELECT '%s';" % snowman) + self.assertEqual(snowman.encode('utf8'), b(cur.fetchone()[0])) + self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), + cur.mogrify(u"SELECT '%s';" % snowman)) - # unicode query with non-ascii data - cur.execute(u"SELECT '%s';" % snowman) - self.assertEqual(snowman.encode('utf8'), b(cur.fetchone()[0])) - self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), - cur.mogrify(u"SELECT '%s';" % snowman)) + # unicode args + cur.execute("SELECT %s;", (snowman,)) + self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0])) + self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), + cur.mogrify("SELECT %s;", (snowman,))) - # unicode args - cur.execute("SELECT %s;", (snowman,)) - self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0])) - self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), - cur.mogrify("SELECT %s;", (snowman,))) - - # unicode query and args - cur.execute(u"SELECT %s;", (snowman,)) - self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0])) - self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), - cur.mogrify(u"SELECT %s;", (snowman,))) + # unicode query and args + cur.execute(u"SELECT %s;", (snowman,)) + self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0])) + self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), + cur.mogrify(u"SELECT %s;", (snowman,))) def test_mogrify_decimal_explodes(self): conn = self.conn - cur = conn.cursor() - self.assertEqual(b'SELECT 10.3;', - cur.mogrify("SELECT %s;", (Decimal("10.3"),))) + with conn.cursor() as cur: + self.assertEqual(b'SELECT 10.3;', + cur.mogrify("SELECT %s;", (Decimal("10.3"),))) @skip_if_no_getrefcount def test_mogrify_leak_on_multiple_reference(self): # issue #81: reference leak when a parameter value is referenced # more than once from a dict. - cur = self.conn.cursor() - foo = (lambda x: x)('foo') * 10 - nref1 = sys.getrefcount(foo) - cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo}) - nref2 = sys.getrefcount(foo) - self.assertEqual(nref1, nref2) + with self.conn.cursor() as cur: + foo = (lambda x: x)('foo') * 10 + nref1 = sys.getrefcount(foo) + cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo}) + nref2 = sys.getrefcount(foo) + self.assertEqual(nref1, nref2) def test_modify_closed(self): cur = self.conn.cursor() @@ -130,52 +129,51 @@ class CursorTests(ConnectingTestCase): self.assertEqual(sql, b"select 10") def test_bad_placeholder(self): - cur = self.conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, - cur.mogrify, "select %(foo", {}) - self.assertRaises(psycopg2.ProgrammingError, - cur.mogrify, "select %(foo", {'foo': 1}) - self.assertRaises(psycopg2.ProgrammingError, - cur.mogrify, "select %(foo, %(bar)", {'foo': 1}) - self.assertRaises(psycopg2.ProgrammingError, - cur.mogrify, "select %(foo, %(bar)", {'foo': 1, 'bar': 2}) + with self.conn.cursor() as cur: + self.assertRaises(psycopg2.ProgrammingError, + cur.mogrify, "select %(foo", {}) + self.assertRaises(psycopg2.ProgrammingError, + cur.mogrify, "select %(foo", {'foo': 1}) + self.assertRaises(psycopg2.ProgrammingError, + cur.mogrify, "select %(foo, %(bar)", {'foo': 1}) + self.assertRaises(psycopg2.ProgrammingError, + cur.mogrify, "select %(foo, %(bar)", {'foo': 1, 'bar': 2}) def test_cast(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + self.assertEqual(42, curs.cast(20, '42')) + self.assertAlmostEqual(3.14, curs.cast(700, '3.14')) - self.assertEqual(42, curs.cast(20, '42')) - self.assertAlmostEqual(3.14, curs.cast(700, '3.14')) + self.assertEqual(Decimal('123.45'), curs.cast(1700, '123.45')) - self.assertEqual(Decimal('123.45'), curs.cast(1700, '123.45')) - - self.assertEqual(date(2011, 1, 2), curs.cast(1082, '2011-01-02')) - self.assertEqual("who am i?", curs.cast(705, 'who am i?')) # unknown + self.assertEqual(date(2011, 1, 2), curs.cast(1082, '2011-01-02')) + self.assertEqual("who am i?", curs.cast(705, 'who am i?')) # unknown def test_cast_specificity(self): - curs = self.conn.cursor() - self.assertEqual("foo", curs.cast(705, 'foo')) + with self.conn.cursor() as curs: + self.assertEqual("foo", curs.cast(705, 'foo')) - D = psycopg2.extensions.new_type((705,), "DOUBLING", lambda v, c: v * 2) - psycopg2.extensions.register_type(D, self.conn) - self.assertEqual("foofoo", curs.cast(705, 'foo')) + D = psycopg2.extensions.new_type((705,), "DOUBLING", lambda v, c: v * 2) + psycopg2.extensions.register_type(D, self.conn) + self.assertEqual("foofoo", curs.cast(705, 'foo')) - T = psycopg2.extensions.new_type((705,), "TREBLING", lambda v, c: v * 3) - psycopg2.extensions.register_type(T, curs) - self.assertEqual("foofoofoo", curs.cast(705, 'foo')) + T = psycopg2.extensions.new_type((705,), "TREBLING", lambda v, c: v * 3) + psycopg2.extensions.register_type(T, curs) + self.assertEqual("foofoofoo", curs.cast(705, 'foo')) - curs2 = self.conn.cursor() - self.assertEqual("foofoo", curs2.cast(705, 'foo')) + with self.conn.cursor() as curs2: + self.assertEqual("foofoo", curs2.cast(705, 'foo')) def test_weakref(self): - curs = self.conn.cursor() - w = ref(curs) + with self.conn.cursor() as curs: + w = ref(curs) del curs gc.collect() self.assert_(w() is None) def test_null_name(self): - curs = self.conn.cursor(None) - self.assertEqual(curs.name, None) + with self.conn.cursor(None) as curs: + self.assertEqual(curs.name, None) def test_invalid_name(self): curs = self.conn.cursor() @@ -184,9 +182,9 @@ class CursorTests(ConnectingTestCase): curs.execute("insert into invname values (%s)", (i,)) curs.close() - curs = self.conn.cursor(r'1-2-3 \ "test"') - curs.execute("select data from invname order by data") - self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) + with self.conn.cursor(r'1-2-3 \ "test"') as curs: + curs.execute("select data from invname order by data") + self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) def _create_withhold_table(self): curs = self.conn.cursor() @@ -213,15 +211,15 @@ class CursorTests(ConnectingTestCase): self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) curs.close() - curs = self.conn.cursor("W", withhold=True) - self.assertEqual(curs.withhold, True) - curs.execute("select data from withhold order by data") - self.conn.commit() - self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) + with self.conn.cursor("W", withhold=True) as curs: + self.assertEqual(curs.withhold, True) + curs.execute("select data from withhold order by data") + self.conn.commit() + self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) - curs = self.conn.cursor() - curs.execute("drop table withhold") - self.conn.commit() + with self.conn.cursor() as curs: + curs.execute("drop table withhold") + self.conn.commit() def test_withhold_no_begin(self): self._create_withhold_table() @@ -328,134 +326,134 @@ class CursorTests(ConnectingTestCase): return self.skipTest("can't evaluate non-scrollable cursor") curs.close() - curs = self.conn.cursor("S", scrollable=False) - self.assertEqual(curs.scrollable, False) - curs.execute("select * from scrollable") - curs.scroll(2) - self.assertRaises(psycopg2.OperationalError, curs.scroll, -1) + with self.conn.cursor("S", scrollable=False) as curs: + self.assertEqual(curs.scrollable, False) + curs.execute("select * from scrollable") + curs.scroll(2) + self.assertRaises(psycopg2.OperationalError, curs.scroll, -1) @slow @skip_before_postgres(8, 2) def test_iter_named_cursor_efficient(self): - curs = self.conn.cursor('tmp') - # if these records are fetched in the same roundtrip their - # timestamp will not be influenced by the pause in Python world. - curs.execute("""select clock_timestamp() from generate_series(1,2)""") - i = iter(curs) - t1 = next(i)[0] - time.sleep(0.2) - t2 = next(i)[0] - self.assert_((t2 - t1).microseconds * 1e-6 < 0.1, - "named cursor records fetched in 2 roundtrips (delta: %s)" - % (t2 - t1)) + with self.conn.cursor('tmp') as curs: + # if these records are fetched in the same roundtrip their + # timestamp will not be influenced by the pause in Python world. + curs.execute("""select clock_timestamp() from generate_series(1,2)""") + i = iter(curs) + t1 = next(i)[0] + time.sleep(0.2) + t2 = next(i)[0] + self.assert_((t2 - t1).microseconds * 1e-6 < 0.1, + "named cursor records fetched in 2 roundtrips (delta: %s)" + % (t2 - t1)) @skip_before_postgres(8, 0) def test_iter_named_cursor_default_itersize(self): - curs = self.conn.cursor('tmp') - curs.execute('select generate_series(1,50)') - rv = [(r[0], curs.rownumber) for r in curs] - # everything swallowed in one gulp - self.assertEqual(rv, [(i, i) for i in range(1, 51)]) + with self.conn.cursor('tmp') as curs: + curs.execute('select generate_series(1,50)') + rv = [(r[0], curs.rownumber) for r in curs] + # everything swallowed in one gulp + self.assertEqual(rv, [(i, i) for i in range(1, 51)]) @skip_before_postgres(8, 0) def test_iter_named_cursor_itersize(self): - curs = self.conn.cursor('tmp') - curs.itersize = 30 - curs.execute('select generate_series(1,50)') - rv = [(r[0], curs.rownumber) for r in curs] - # everything swallowed in two gulps - self.assertEqual(rv, [(i, ((i - 1) % 30) + 1) for i in range(1, 51)]) + with self.conn.cursor('tmp') as curs: + curs.itersize = 30 + curs.execute('select generate_series(1,50)') + rv = [(r[0], curs.rownumber) for r in curs] + # everything swallowed in two gulps + self.assertEqual(rv, [(i, ((i - 1) % 30) + 1) for i in range(1, 51)]) @skip_before_postgres(8, 0) def test_iter_named_cursor_rownumber(self): - curs = self.conn.cursor('tmp') - # note: this fails if itersize < dataset: internally we check - # rownumber == rowcount to detect when to read anoter page, so we - # would need an extra attribute to have a monotonic rownumber. - curs.itersize = 20 - curs.execute('select generate_series(1,10)') - for i, rec in enumerate(curs): - self.assertEqual(i + 1, curs.rownumber) + with self.conn.cursor('tmp') as curs: + # note: this fails if itersize < dataset: internally we check + # rownumber == rowcount to detect when to read anoter page, so we + # would need an extra attribute to have a monotonic rownumber. + curs.itersize = 20 + curs.execute('select generate_series(1,10)') + for i, rec in enumerate(curs): + self.assertEqual(i + 1, curs.rownumber) def test_description_attribs(self): - curs = self.conn.cursor() - curs.execute("""select - 3.14::decimal(10,2) as pi, - 'hello'::text as hi, - '2010-02-18'::date as now; - """) - self.assertEqual(len(curs.description), 3) - for c in curs.description: - self.assertEqual(len(c), 7) # DBAPI happy - for a in ('name', 'type_code', 'display_size', 'internal_size', - 'precision', 'scale', 'null_ok'): - self.assert_(hasattr(c, a), a) + with self.conn.cursor() as curs: + curs.execute("""select + 3.14::decimal(10,2) as pi, + 'hello'::text as hi, + '2010-02-18'::date as now; + """) + self.assertEqual(len(curs.description), 3) + for c in curs.description: + self.assertEqual(len(c), 7) # DBAPI happy + for a in ('name', 'type_code', 'display_size', 'internal_size', + 'precision', 'scale', 'null_ok'): + self.assert_(hasattr(c, a), a) - c = curs.description[0] - self.assertEqual(c.name, 'pi') - self.assert_(c.type_code in psycopg2.extensions.DECIMAL.values) - self.assert_(c.internal_size > 0) - self.assertEqual(c.precision, 10) - self.assertEqual(c.scale, 2) + c = curs.description[0] + self.assertEqual(c.name, 'pi') + self.assert_(c.type_code in psycopg2.extensions.DECIMAL.values) + self.assert_(c.internal_size > 0) + self.assertEqual(c.precision, 10) + self.assertEqual(c.scale, 2) - c = curs.description[1] - self.assertEqual(c.name, 'hi') - self.assert_(c.type_code in psycopg2.STRING.values) - self.assert_(c.internal_size < 0) - self.assertEqual(c.precision, None) - self.assertEqual(c.scale, None) + c = curs.description[1] + self.assertEqual(c.name, 'hi') + self.assert_(c.type_code in psycopg2.STRING.values) + self.assert_(c.internal_size < 0) + self.assertEqual(c.precision, None) + self.assertEqual(c.scale, None) - c = curs.description[2] - self.assertEqual(c.name, 'now') - self.assert_(c.type_code in psycopg2.extensions.DATE.values) - self.assert_(c.internal_size > 0) - self.assertEqual(c.precision, None) - self.assertEqual(c.scale, None) + c = curs.description[2] + self.assertEqual(c.name, 'now') + self.assert_(c.type_code in psycopg2.extensions.DATE.values) + self.assert_(c.internal_size > 0) + self.assertEqual(c.precision, None) + self.assertEqual(c.scale, None) def test_description_extra_attribs(self): - curs = self.conn.cursor() - curs.execute(""" - create table testcol ( - pi decimal(10,2), - hi text) - """) - curs.execute("select oid from pg_class where relname = %s", ('testcol',)) - oid = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute(""" + create table testcol ( + pi decimal(10,2), + hi text) + """) + curs.execute("select oid from pg_class where relname = %s", ('testcol',)) + oid = curs.fetchone()[0] - curs.execute("insert into testcol values (3.14, 'hello')") - curs.execute("select hi, pi, 42 from testcol") - self.assertEqual(curs.description[0].table_oid, oid) - self.assertEqual(curs.description[0].table_column, 2) + curs.execute("insert into testcol values (3.14, 'hello')") + curs.execute("select hi, pi, 42 from testcol") + self.assertEqual(curs.description[0].table_oid, oid) + self.assertEqual(curs.description[0].table_column, 2) - self.assertEqual(curs.description[1].table_oid, oid) - self.assertEqual(curs.description[1].table_column, 1) + self.assertEqual(curs.description[1].table_oid, oid) + self.assertEqual(curs.description[1].table_column, 1) - self.assertEqual(curs.description[2].table_oid, None) - self.assertEqual(curs.description[2].table_column, None) + self.assertEqual(curs.description[2].table_oid, None) + self.assertEqual(curs.description[2].table_column, None) def test_pickle_description(self): - curs = self.conn.cursor() - curs.execute('SELECT 1 AS foo') - description = curs.description + with self.conn.cursor() as curs: + curs.execute('SELECT 1 AS foo') + description = curs.description - pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL) - unpickled = pickle.loads(pickled) + pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL) + unpickled = pickle.loads(pickled) - self.assertEqual(description, unpickled) + self.assertEqual(description, unpickled) @skip_before_postgres(8, 0) def test_named_cursor_stealing(self): # you can use a named cursor to iterate on a refcursor created # somewhere else - cur1 = self.conn.cursor() - cur1.execute("DECLARE test CURSOR WITHOUT HOLD " - " FOR SELECT generate_series(1,7)") + with self.conn.cursor() as cur1: + cur1.execute("DECLARE test CURSOR WITHOUT HOLD " + " FOR SELECT generate_series(1,7)") - cur2 = self.conn.cursor('test') - # can call fetch without execute - self.assertEqual((1,), cur2.fetchone()) - self.assertEqual([(2,), (3,), (4,)], cur2.fetchmany(3)) - self.assertEqual([(5,), (6,), (7,)], cur2.fetchall()) + with self.conn.cursor('test') as cur2: + # can call fetch without execute + self.assertEqual((1,), cur2.fetchone()) + self.assertEqual([(2,), (3,), (4,)], cur2.fetchmany(3)) + self.assertEqual([(5,), (6,), (7,)], cur2.fetchall()) @skip_before_postgres(8, 2) def test_named_noop_close(self): @@ -464,63 +462,63 @@ class CursorTests(ConnectingTestCase): @skip_before_postgres(8, 2) def test_stolen_named_cursor_close(self): - cur1 = self.conn.cursor() - cur1.execute("DECLARE test CURSOR WITHOUT HOLD " - " FOR SELECT generate_series(1,7)") - cur2 = self.conn.cursor('test') - cur2.close() + with self.conn.cursor() as cur1: + cur1.execute("DECLARE test CURSOR WITHOUT HOLD " + " FOR SELECT generate_series(1,7)") + cur2 = self.conn.cursor('test') + cur2.close() - cur1.execute("DECLARE test CURSOR WITHOUT HOLD " - " FOR SELECT generate_series(1,7)") - cur2 = self.conn.cursor('test') - cur2.close() + cur1.execute("DECLARE test CURSOR WITHOUT HOLD " + " FOR SELECT generate_series(1,7)") + cur2 = self.conn.cursor('test') + cur2.close() @skip_before_postgres(8, 0) def test_scroll(self): - cur = self.conn.cursor() - cur.execute("select generate_series(0,9)") - cur.scroll(2) - self.assertEqual(cur.fetchone(), (2,)) - cur.scroll(2) - self.assertEqual(cur.fetchone(), (5,)) - cur.scroll(2, mode='relative') - self.assertEqual(cur.fetchone(), (8,)) - cur.scroll(-1) - self.assertEqual(cur.fetchone(), (8,)) - cur.scroll(-2) - self.assertEqual(cur.fetchone(), (7,)) - cur.scroll(2, mode='absolute') - self.assertEqual(cur.fetchone(), (2,)) + with self.conn.cursor() as cur: + cur.execute("select generate_series(0,9)") + cur.scroll(2) + self.assertEqual(cur.fetchone(), (2,)) + cur.scroll(2) + self.assertEqual(cur.fetchone(), (5,)) + cur.scroll(2, mode='relative') + self.assertEqual(cur.fetchone(), (8,)) + cur.scroll(-1) + self.assertEqual(cur.fetchone(), (8,)) + cur.scroll(-2) + self.assertEqual(cur.fetchone(), (7,)) + cur.scroll(2, mode='absolute') + self.assertEqual(cur.fetchone(), (2,)) - # on the boundary - cur.scroll(0, mode='absolute') - self.assertEqual(cur.fetchone(), (0,)) - self.assertRaises((IndexError, psycopg2.ProgrammingError), - cur.scroll, -1, mode='absolute') - cur.scroll(0, mode='absolute') - self.assertRaises((IndexError, psycopg2.ProgrammingError), - cur.scroll, -1) + # on the boundary + cur.scroll(0, mode='absolute') + self.assertEqual(cur.fetchone(), (0,)) + self.assertRaises((IndexError, psycopg2.ProgrammingError), + cur.scroll, -1, mode='absolute') + cur.scroll(0, mode='absolute') + self.assertRaises((IndexError, psycopg2.ProgrammingError), + cur.scroll, -1) - cur.scroll(9, mode='absolute') - self.assertEqual(cur.fetchone(), (9,)) - self.assertRaises((IndexError, psycopg2.ProgrammingError), - cur.scroll, 10, mode='absolute') - cur.scroll(9, mode='absolute') - self.assertRaises((IndexError, psycopg2.ProgrammingError), - cur.scroll, 1) + cur.scroll(9, mode='absolute') + self.assertEqual(cur.fetchone(), (9,)) + self.assertRaises((IndexError, psycopg2.ProgrammingError), + cur.scroll, 10, mode='absolute') + cur.scroll(9, mode='absolute') + self.assertRaises((IndexError, psycopg2.ProgrammingError), + cur.scroll, 1) @skip_before_postgres(8, 0) def test_scroll_named(self): - cur = self.conn.cursor('tmp', scrollable=True) - cur.execute("select generate_series(0,9)") - cur.scroll(2) - self.assertEqual(cur.fetchone(), (2,)) - cur.scroll(2) - self.assertEqual(cur.fetchone(), (5,)) - cur.scroll(2, mode='relative') - self.assertEqual(cur.fetchone(), (8,)) - cur.scroll(9, mode='absolute') - self.assertEqual(cur.fetchone(), (9,)) + with self.conn.cursor('tmp', scrollable=True) as cur: + cur.execute("select generate_series(0,9)") + cur.scroll(2) + self.assertEqual(cur.fetchone(), (2,)) + cur.scroll(2) + self.assertEqual(cur.fetchone(), (5,)) + cur.scroll(2, mode='relative') + self.assertEqual(cur.fetchone(), (8,)) + cur.scroll(9, mode='absolute') + self.assertEqual(cur.fetchone(), (9,)) def test_bad_subclass(self): # check that we get an error message instead of a segfault @@ -531,14 +529,14 @@ class CursorTests(ConnectingTestCase): # I am stupid so not calling superclass init pass - cur = StupidCursor() - self.assertRaises(psycopg2.InterfaceError, cur.execute, 'select 1') - self.assertRaises(psycopg2.InterfaceError, cur.executemany, - 'select 1', []) + with StupidCursor() as cur: + self.assertRaises(psycopg2.InterfaceError, cur.execute, 'select 1') + self.assertRaises(psycopg2.InterfaceError, cur.executemany, + 'select 1', []) def test_callproc_badparam(self): - cur = self.conn.cursor() - self.assertRaises(TypeError, cur.callproc, 'lower', 42) + with self.conn.cursor() as cur: + self.assertRaises(TypeError, cur.callproc, 'lower', 42) # It would be inappropriate to test callproc's named parameters in the # DBAPI2.0 test section because they are a psycopg2 extension. @@ -551,32 +549,31 @@ class CursorTests(ConnectingTestCase): escaped_paramname = '"%s"' % paramname.replace('"', '""') procname = 'pg_temp.randall' - cur = self.conn.cursor() + with self.conn.cursor() as cur: + # Set up the temporary function + cur.execute(''' + CREATE FUNCTION %s(%s INT) + RETURNS INT AS + 'SELECT $1 * $1' + LANGUAGE SQL + ''' % (procname, escaped_paramname)) - # Set up the temporary function - cur.execute(''' - CREATE FUNCTION %s(%s INT) - RETURNS INT AS - 'SELECT $1 * $1' - LANGUAGE SQL - ''' % (procname, escaped_paramname)) + # Make sure callproc works right + cur.callproc(procname, {paramname: 2}) + self.assertEquals(cur.fetchone()[0], 4) - # Make sure callproc works right - cur.callproc(procname, {paramname: 2}) - self.assertEquals(cur.fetchone()[0], 4) - - # Make sure callproc fails right - failing_cases = [ - ({paramname: 2, 'foo': 'bar'}, psycopg2.ProgrammingError), - ({paramname: '2'}, psycopg2.ProgrammingError), - ({paramname: 'two'}, psycopg2.ProgrammingError), - ({u'bj\xc3rn': 2}, psycopg2.ProgrammingError), - ({3: 2}, TypeError), - ({self: 2}, TypeError), - ] - for parameter_sequence, exception in failing_cases: - self.assertRaises(exception, cur.callproc, procname, parameter_sequence) - self.conn.rollback() + # Make sure callproc fails right + failing_cases = [ + ({paramname: 2, 'foo': 'bar'}, psycopg2.ProgrammingError), + ({paramname: '2'}, psycopg2.ProgrammingError), + ({paramname: 'two'}, psycopg2.ProgrammingError), + ({u'bj\xc3rn': 2}, psycopg2.ProgrammingError), + ({3: 2}, TypeError), + ({self: 2}, TypeError), + ] + for parameter_sequence, exception in failing_cases: + self.assertRaises(exception, cur.callproc, procname, parameter_sequence) + self.conn.rollback() @skip_if_no_superuser @skip_if_windows @@ -638,17 +635,17 @@ class CursorTests(ConnectingTestCase): @skip_before_postgres(8, 2) def test_rowcount_on_executemany_returning(self): - cur = self.conn.cursor() - cur.execute("create table execmany(id serial primary key, data int)") - cur.executemany( - "insert into execmany (data) values (%s)", - [(i,) for i in range(4)]) - self.assertEqual(cur.rowcount, 4) + with self.conn.cursor() as cur: + cur.execute("create table execmany(id serial primary key, data int)") + cur.executemany( + "insert into execmany (data) values (%s)", + [(i,) for i in range(4)]) + self.assertEqual(cur.rowcount, 4) - cur.executemany( - "insert into execmany (data) values (%s) returning data", - [(i,) for i in range(5)]) - self.assertEqual(cur.rowcount, 5) + cur.executemany( + "insert into execmany (data) values (%s) returning data", + [(i,) for i in range(5)]) + self.assertEqual(cur.rowcount, 5) @skip_before_postgres(9) def test_pgresult_ptr(self): diff --git a/tests/test_dates.py b/tests/test_dates.py index b9aac695..829461b9 100755 --- a/tests/test_dates.py +++ b/tests/test_dates.py @@ -369,22 +369,22 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin): self.assertEqual(total_seconds(t), 1e-6) def test_interval_overflow(self): - cur = self.conn.cursor() - # hack a cursor to receive values too extreme to be represented - # but still I want an error, not a random number - psycopg2.extensions.register_type( - psycopg2.extensions.new_type( - psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL), - cur) + with self.conn.cursor() as cur: + # hack a cursor to receive values too extreme to be represented + # but still I want an error, not a random number + psycopg2.extensions.register_type( + psycopg2.extensions.new_type( + psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL), + cur) - def f(val): - cur.execute("select '%s'::text" % val) - return cur.fetchone()[0] + def f(val): + cur.execute("select '%s'::text" % val) + return cur.fetchone()[0] - self.assertRaises(OverflowError, f, '100000000000000000:00:00') - self.assertRaises(OverflowError, f, '00:100000000000000000:00:00') - self.assertRaises(OverflowError, f, '00:00:100000000000000000:00') - self.assertRaises(OverflowError, f, '00:00:00.100000000000000000') + self.assertRaises(OverflowError, f, '100000000000000000:00:00') + self.assertRaises(OverflowError, f, '00:100000000000000000:00:00') + self.assertRaises(OverflowError, f, '00:00:100000000000000000:00') + self.assertRaises(OverflowError, f, '00:00:00.100000000000000000') def test_adapt_infinity_tz(self): t = self.execute("select 'infinity'::timestamp") @@ -405,31 +405,31 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin): def test_redshift_day(self): # Redshift is reported returning 1 day interval as microsec (bug #558) - cur = self.conn.cursor() - psycopg2.extensions.register_type( - psycopg2.extensions.new_type( - psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL), - cur) + with self.conn.cursor() as cur: + psycopg2.extensions.register_type( + psycopg2.extensions.new_type( + psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL), + cur) - for s, v in [ - ('0', timedelta(0)), - ('1', timedelta(microseconds=1)), - ('-1', timedelta(microseconds=-1)), - ('1000000', timedelta(seconds=1)), - ('86400000000', timedelta(days=1)), - ('-86400000000', timedelta(days=-1)), - ]: - cur.execute("select %s::text", (s,)) - r = cur.fetchone()[0] - self.assertEqual(r, v, "%s -> %s != %s" % (s, r, v)) + for s, v in [ + ('0', timedelta(0)), + ('1', timedelta(microseconds=1)), + ('-1', timedelta(microseconds=-1)), + ('1000000', timedelta(seconds=1)), + ('86400000000', timedelta(days=1)), + ('-86400000000', timedelta(days=-1)), + ]: + cur.execute("select %s::text", (s,)) + r = cur.fetchone()[0] + self.assertEqual(r, v, "%s -> %s != %s" % (s, r, v)) @skip_before_postgres(8, 4) def test_interval_iso_8601_not_supported(self): # We may end up supporting, but no pressure for it - cur = self.conn.cursor() - cur.execute("set local intervalstyle to iso_8601") - cur.execute("select '1 day 2 hours'::interval") - self.assertRaises(psycopg2.NotSupportedError, cur.fetchone) + with self.conn.cursor() as cur: + cur.execute("set local intervalstyle to iso_8601") + cur.execute("select '1 day 2 hours'::interval") + self.assertRaises(psycopg2.NotSupportedError, cur.fetchone) @unittest.skipUnless( diff --git a/tests/test_extras_dictcursor.py b/tests/test_extras_dictcursor.py index d4bb12f5..80fd4cca 100755 --- a/tests/test_extras_dictcursor.py +++ b/tests/test_extras_dictcursor.py @@ -32,9 +32,9 @@ from .testutils import ConnectingTestCase, skip_before_postgres, \ class _DictCursorBase(ConnectingTestCase): def setUp(self): ConnectingTestCase.setUp(self) - curs = self.conn.cursor() - curs.execute("CREATE TEMPORARY TABLE ExtrasDictCursorTests (foo text)") - curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')") + with self.conn.cursor() as curs: + curs.execute("CREATE TEMPORARY TABLE ExtrasDictCursorTests (foo text)") + curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')") self.conn.commit() def _testIterRowNumber(self, curs): @@ -61,17 +61,20 @@ class _DictCursorBase(ConnectingTestCase): class ExtrasDictCursorTests(_DictCursorBase): """Test if DictCursor extension class works.""" + @skip_before_postgres(8, 2) def testDictConnCursorArgs(self): self.conn.close() self.conn = self.connect(connection_factory=psycopg2.extras.DictConnection) - cur = self.conn.cursor() - self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) - self.assertEqual(cur.name, None) - # overridable - cur = self.conn.cursor('foo', - cursor_factory=psycopg2.extras.NamedTupleCursor) - self.assertEqual(cur.name, 'foo') - self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor)) + with self.conn.cursor() as cur: + self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) + self.assertEqual(cur.name, None) + # overridable + with self.conn.cursor( + 'foo', + cursor_factory=psycopg2.extras.NamedTupleCursor + ) as cur: + self.assertEqual(cur.name, 'foo') + self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor)) def testDictCursorWithPlainCursorFetchOne(self): self._testWithPlainCursor(lambda curs: curs.fetchone()) @@ -99,13 +102,13 @@ class ExtrasDictCursorTests(_DictCursorBase): @skip_before_postgres(8, 0) def testDictCursorWithPlainCursorIterRowNumber(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) - self._testIterRowNumber(curs) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs: + self._testIterRowNumber(curs) def _testWithPlainCursor(self, getter): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) - curs.execute("SELECT * FROM ExtrasDictCursorTests") - row = getter(curs) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs: + curs.execute("SELECT * FROM ExtrasDictCursorTests") + row = getter(curs) self.failUnless(row['foo'] == 'bar') self.failUnless(row[0] == 'bar') return row @@ -130,80 +133,89 @@ class ExtrasDictCursorTests(_DictCursorBase): @skip_before_postgres(8, 2) def testDictCursorWithNamedCursorNotGreedy(self): - curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.DictCursor) - self._testNamedCursorNotGreedy(curs) + with self.conn.cursor( + 'tmp', + cursor_factory=psycopg2.extras.DictCursor + ) as curs: + self._testNamedCursorNotGreedy(curs) @skip_before_postgres(8, 0) def testDictCursorWithNamedCursorIterRowNumber(self): - curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.DictCursor) - self._testIterRowNumber(curs) + with self.conn.cursor( + 'tmp', + cursor_factory=psycopg2.extras.DictCursor + ) as curs: + self._testIterRowNumber(curs) def _testWithNamedCursor(self, getter): - curs = self.conn.cursor('aname', cursor_factory=psycopg2.extras.DictCursor) - curs.execute("SELECT * FROM ExtrasDictCursorTests") - row = getter(curs) - self.failUnless(row['foo'] == 'bar') - self.failUnless(row[0] == 'bar') + with self.conn.cursor( + 'aname', + cursor_factory=psycopg2.extras.DictCursor + ) as curs: + curs.execute("SELECT * FROM ExtrasDictCursorTests") + row = getter(curs) + self.failUnless(row['foo'] == 'bar') + self.failUnless(row[0] == 'bar') def testPickleDictRow(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) - curs.execute("select 10 as a, 20 as b") - r = curs.fetchone() - d = pickle.dumps(r) - r1 = pickle.loads(d) - self.assertEqual(r, r1) - self.assertEqual(r[0], r1[0]) - self.assertEqual(r[1], r1[1]) - self.assertEqual(r['a'], r1['a']) - self.assertEqual(r['b'], r1['b']) - self.assertEqual(r._index, r1._index) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs: + curs.execute("select 10 as a, 20 as b") + r = curs.fetchone() + d = pickle.dumps(r) + r1 = pickle.loads(d) + self.assertEqual(r, r1) + self.assertEqual(r[0], r1[0]) + self.assertEqual(r[1], r1[1]) + self.assertEqual(r['a'], r1['a']) + self.assertEqual(r['b'], r1['b']) + self.assertEqual(r._index, r1._index) @skip_from_python(3) def test_iter_methods_2(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) - curs.execute("select 10 as a, 20 as b") - r = curs.fetchone() - self.assert_(isinstance(r.keys(), list)) - self.assertEqual(len(r.keys()), 2) - self.assert_(isinstance(r.values(), tuple)) # sic? - self.assertEqual(len(r.values()), 2) - self.assert_(isinstance(r.items(), list)) - self.assertEqual(len(r.items()), 2) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs: + curs.execute("select 10 as a, 20 as b") + r = curs.fetchone() + self.assert_(isinstance(r.keys(), list)) + self.assertEqual(len(r.keys()), 2) + self.assert_(isinstance(r.values(), tuple)) # sic? + self.assertEqual(len(r.values()), 2) + self.assert_(isinstance(r.items(), list)) + self.assertEqual(len(r.items()), 2) - self.assert_(not isinstance(r.iterkeys(), list)) - self.assertEqual(len(list(r.iterkeys())), 2) - self.assert_(not isinstance(r.itervalues(), list)) - self.assertEqual(len(list(r.itervalues())), 2) - self.assert_(not isinstance(r.iteritems(), list)) - self.assertEqual(len(list(r.iteritems())), 2) + self.assert_(not isinstance(r.iterkeys(), list)) + self.assertEqual(len(list(r.iterkeys())), 2) + self.assert_(not isinstance(r.itervalues(), list)) + self.assertEqual(len(list(r.itervalues())), 2) + self.assert_(not isinstance(r.iteritems(), list)) + self.assertEqual(len(list(r.iteritems())), 2) @skip_before_python(3) def test_iter_methods_3(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) - curs.execute("select 10 as a, 20 as b") - r = curs.fetchone() - self.assert_(not isinstance(r.keys(), list)) - self.assertEqual(len(list(r.keys())), 2) - self.assert_(not isinstance(r.values(), list)) - self.assertEqual(len(list(r.values())), 2) - self.assert_(not isinstance(r.items(), list)) - self.assertEqual(len(list(r.items())), 2) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs: + curs.execute("select 10 as a, 20 as b") + r = curs.fetchone() + self.assert_(not isinstance(r.keys(), list)) + self.assertEqual(len(list(r.keys())), 2) + self.assert_(not isinstance(r.values(), list)) + self.assertEqual(len(list(r.values())), 2) + self.assert_(not isinstance(r.items(), list)) + self.assertEqual(len(list(r.items())), 2) def test_order(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) - curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") - r = curs.fetchone() - self.assertEqual(list(r), [5, 4, 33, 2]) - self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux']) - self.assertEqual(list(r.values()), [5, 4, 33, 2]) - self.assertEqual(list(r.items()), - [('foo', 5), ('bar', 4), ('baz', 33), ('qux', 2)]) + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs: + curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") + r = curs.fetchone() + self.assertEqual(list(r), [5, 4, 33, 2]) + self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux']) + self.assertEqual(list(r.values()), [5, 4, 33, 2]) + self.assertEqual(list(r.items()), + [('foo', 5), ('bar', 4), ('baz', 33), ('qux', 2)]) - r1 = pickle.loads(pickle.dumps(r)) - self.assertEqual(list(r1), list(r)) - self.assertEqual(list(r1.keys()), list(r.keys())) - self.assertEqual(list(r1.values()), list(r.values())) - self.assertEqual(list(r1.items()), list(r.items())) + r1 = pickle.loads(pickle.dumps(r)) + self.assertEqual(list(r1), list(r)) + self.assertEqual(list(r1.keys()), list(r.keys())) + self.assertEqual(list(r1.values()), list(r.values())) + self.assertEqual(list(r1.items()), list(r.items())) @skip_from_python(3) def test_order_iter(self): @@ -223,10 +235,10 @@ class ExtrasDictCursorTests(_DictCursorBase): class ExtrasDictCursorRealTests(_DictCursorBase): def testRealMeansReal(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("SELECT * FROM ExtrasDictCursorTests") - row = curs.fetchone() - self.assert_(isinstance(row, dict)) + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("SELECT * FROM ExtrasDictCursorTests") + row = curs.fetchone() + self.assert_(isinstance(row, dict)) def testDictCursorWithPlainCursorRealFetchOne(self): self._testWithPlainCursorReal(lambda curs: curs.fetchone()) @@ -248,24 +260,24 @@ class ExtrasDictCursorRealTests(_DictCursorBase): @skip_before_postgres(8, 0) def testDictCursorWithPlainCursorRealIterRowNumber(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - self._testIterRowNumber(curs) + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + self._testIterRowNumber(curs) def _testWithPlainCursorReal(self, getter): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("SELECT * FROM ExtrasDictCursorTests") - row = getter(curs) - self.failUnless(row['foo'] == 'bar') + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("SELECT * FROM ExtrasDictCursorTests") + row = getter(curs) + self.failUnless(row['foo'] == 'bar') def testPickleRealDictRow(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 10 as a, 20 as b") - r = curs.fetchone() - d = pickle.dumps(r) - r1 = pickle.loads(d) - self.assertEqual(r, r1) - self.assertEqual(r['a'], r1['a']) - self.assertEqual(r['b'], r1['b']) + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 10 as a, 20 as b") + r = curs.fetchone() + d = pickle.dumps(r) + r1 = pickle.loads(d) + self.assertEqual(r, r1) + self.assertEqual(r['a'], r1['a']) + self.assertEqual(r['b'], r1['b']) def testDictCursorRealWithNamedCursorFetchOne(self): self._testWithNamedCursorReal(lambda curs: curs.fetchone()) @@ -287,26 +299,34 @@ class ExtrasDictCursorRealTests(_DictCursorBase): @skip_before_postgres(8, 2) def testDictCursorRealWithNamedCursorNotGreedy(self): - curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.RealDictCursor) - self._testNamedCursorNotGreedy(curs) + with self.conn.cursor( + 'tmp', + cursor_factory=psycopg2.extras.RealDictCursor + ) as curs: + self._testNamedCursorNotGreedy(curs) @skip_before_postgres(8, 0) def testDictCursorRealWithNamedCursorIterRowNumber(self): - curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.RealDictCursor) - self._testIterRowNumber(curs) + with self.conn.cursor( + 'tmp', + cursor_factory=psycopg2.extras.RealDictCursor + ) as curs: + self._testIterRowNumber(curs) def _testWithNamedCursorReal(self, getter): - curs = self.conn.cursor('aname', - cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("SELECT * FROM ExtrasDictCursorTests") - row = getter(curs) - self.failUnless(row['foo'] == 'bar') + with self.conn.cursor( + 'aname', + cursor_factory=psycopg2.extras.RealDictCursor + ) as curs: + curs.execute("SELECT * FROM ExtrasDictCursorTests") + row = getter(curs) + self.failUnless(row['foo'] == 'bar') @skip_from_python(3) def test_iter_methods_2(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 10 as a, 20 as b") - r = curs.fetchone() + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 10 as a, 20 as b") + r = curs.fetchone() self.assert_(isinstance(r.keys(), list)) self.assertEqual(len(r.keys()), 2) self.assert_(isinstance(r.values(), list)) @@ -323,9 +343,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase): @skip_before_python(3) def test_iter_methods_3(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 10 as a, 20 as b") - r = curs.fetchone() + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 10 as a, 20 as b") + r = curs.fetchone() self.assert_(not isinstance(r.keys(), list)) self.assertEqual(len(list(r.keys())), 2) self.assert_(not isinstance(r.values(), list)) @@ -334,9 +354,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase): self.assertEqual(len(list(r.items())), 2) def test_order(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") - r = curs.fetchone() + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") + r = curs.fetchone() self.assertEqual(list(r), ['foo', 'bar', 'baz', 'qux']) self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux']) self.assertEqual(list(r.values()), [5, 4, 33, 2]) @@ -351,9 +371,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase): @skip_from_python(3) def test_order_iter(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") - r = curs.fetchone() + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") + r = curs.fetchone() self.assertEqual(list(r.iterkeys()), ['foo', 'bar', 'baz', 'qux']) self.assertEqual(list(r.itervalues()), [5, 4, 33, 2]) self.assertEqual(list(r.iteritems()), @@ -365,9 +385,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase): self.assertEqual(list(r1.iteritems()), list(r.iteritems())) def test_pop(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 1 as a, 2 as b, 3 as c") - r = curs.fetchone() + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 1 as a, 2 as b, 3 as c") + r = curs.fetchone() self.assertEqual(r.pop('b'), 2) self.assertEqual(list(r), ['a', 'c']) self.assertEqual(list(r.keys()), ['a', 'c']) @@ -378,9 +398,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase): self.assertRaises(KeyError, r.pop, 'b') def test_mod(self): - curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - curs.execute("select 1 as a, 2 as b, 3 as c") - r = curs.fetchone() + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs: + curs.execute("select 1 as a, 2 as b, 3 as c") + r = curs.fetchone() r['d'] = 4 self.assertEqual(list(r), ['a', 'b', 'c', 'd']) self.assertEqual(list(r.keys()), ['a', 'b', 'c', 'd']) @@ -399,137 +419,141 @@ class NamedTupleCursorTest(ConnectingTestCase): ConnectingTestCase.setUp(self) self.conn = self.connect(connection_factory=NamedTupleConnection) - curs = self.conn.cursor() - curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)") - curs.execute("INSERT INTO nttest VALUES (1, 'foo')") - curs.execute("INSERT INTO nttest VALUES (2, 'bar')") - curs.execute("INSERT INTO nttest VALUES (3, 'baz')") + with self.conn.cursor() as curs: + curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)") + curs.execute("INSERT INTO nttest VALUES (1, 'foo')") + curs.execute("INSERT INTO nttest VALUES (2, 'bar')") + curs.execute("INSERT INTO nttest VALUES (3, 'baz')") self.conn.commit() + @skip_before_postgres(8, 2) def test_cursor_args(self): - cur = self.conn.cursor('foo', cursor_factory=psycopg2.extras.DictCursor) - self.assertEqual(cur.name, 'foo') - self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) + with self.conn.cursor( + 'foo', + cursor_factory=psycopg2.extras.DictCursor + ) as cur: + self.assertEqual(cur.name, 'foo') + self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) def test_fetchone(self): - curs = self.conn.cursor() - curs.execute("select * from nttest order by 1") - t = curs.fetchone() - self.assertEqual(t[0], 1) - self.assertEqual(t.i, 1) - self.assertEqual(t[1], 'foo') - self.assertEqual(t.s, 'foo') - self.assertEqual(curs.rownumber, 1) - self.assertEqual(curs.rowcount, 3) + with self.conn.cursor() as curs: + curs.execute("select * from nttest order by 1") + t = curs.fetchone() + self.assertEqual(t[0], 1) + self.assertEqual(t.i, 1) + self.assertEqual(t[1], 'foo') + self.assertEqual(t.s, 'foo') + self.assertEqual(curs.rownumber, 1) + self.assertEqual(curs.rowcount, 3) def test_fetchmany_noarg(self): - curs = self.conn.cursor() - curs.arraysize = 2 - curs.execute("select * from nttest order by 1") - res = curs.fetchmany() - self.assertEqual(2, len(res)) - self.assertEqual(res[0].i, 1) - self.assertEqual(res[0].s, 'foo') - self.assertEqual(res[1].i, 2) - self.assertEqual(res[1].s, 'bar') - self.assertEqual(curs.rownumber, 2) - self.assertEqual(curs.rowcount, 3) + with self.conn.cursor() as curs: + curs.arraysize = 2 + curs.execute("select * from nttest order by 1") + res = curs.fetchmany() + self.assertEqual(2, len(res)) + self.assertEqual(res[0].i, 1) + self.assertEqual(res[0].s, 'foo') + self.assertEqual(res[1].i, 2) + self.assertEqual(res[1].s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) def test_fetchmany(self): - curs = self.conn.cursor() - curs.execute("select * from nttest order by 1") - res = curs.fetchmany(2) - self.assertEqual(2, len(res)) - self.assertEqual(res[0].i, 1) - self.assertEqual(res[0].s, 'foo') - self.assertEqual(res[1].i, 2) - self.assertEqual(res[1].s, 'bar') - self.assertEqual(curs.rownumber, 2) - self.assertEqual(curs.rowcount, 3) + with self.conn.cursor() as curs: + curs.execute("select * from nttest order by 1") + res = curs.fetchmany(2) + self.assertEqual(2, len(res)) + self.assertEqual(res[0].i, 1) + self.assertEqual(res[0].s, 'foo') + self.assertEqual(res[1].i, 2) + self.assertEqual(res[1].s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) def test_fetchall(self): - curs = self.conn.cursor() - curs.execute("select * from nttest order by 1") - res = curs.fetchall() - self.assertEqual(3, len(res)) - self.assertEqual(res[0].i, 1) - self.assertEqual(res[0].s, 'foo') - self.assertEqual(res[1].i, 2) - self.assertEqual(res[1].s, 'bar') - self.assertEqual(res[2].i, 3) - self.assertEqual(res[2].s, 'baz') - self.assertEqual(curs.rownumber, 3) - self.assertEqual(curs.rowcount, 3) + with self.conn.cursor() as curs: + curs.execute("select * from nttest order by 1") + res = curs.fetchall() + self.assertEqual(3, len(res)) + self.assertEqual(res[0].i, 1) + self.assertEqual(res[0].s, 'foo') + self.assertEqual(res[1].i, 2) + self.assertEqual(res[1].s, 'bar') + self.assertEqual(res[2].i, 3) + self.assertEqual(res[2].s, 'baz') + self.assertEqual(curs.rownumber, 3) + self.assertEqual(curs.rowcount, 3) def test_executemany(self): - curs = self.conn.cursor() - curs.executemany("delete from nttest where i = %s", - [(1,), (2,)]) - curs.execute("select * from nttest order by 1") - res = curs.fetchall() + with self.conn.cursor() as curs: + curs.executemany("delete from nttest where i = %s", + [(1,), (2,)]) + curs.execute("select * from nttest order by 1") + res = curs.fetchall() self.assertEqual(1, len(res)) self.assertEqual(res[0].i, 3) self.assertEqual(res[0].s, 'baz') def test_iter(self): - curs = self.conn.cursor() - curs.execute("select * from nttest order by 1") - i = iter(curs) - self.assertEqual(curs.rownumber, 0) + with self.conn.cursor() as curs: + curs.execute("select * from nttest order by 1") + i = iter(curs) + self.assertEqual(curs.rownumber, 0) - t = next(i) - self.assertEqual(t.i, 1) - self.assertEqual(t.s, 'foo') - self.assertEqual(curs.rownumber, 1) - self.assertEqual(curs.rowcount, 3) + t = next(i) + self.assertEqual(t.i, 1) + self.assertEqual(t.s, 'foo') + self.assertEqual(curs.rownumber, 1) + self.assertEqual(curs.rowcount, 3) - t = next(i) - self.assertEqual(t.i, 2) - self.assertEqual(t.s, 'bar') - self.assertEqual(curs.rownumber, 2) - self.assertEqual(curs.rowcount, 3) + t = next(i) + self.assertEqual(t.i, 2) + self.assertEqual(t.s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) - t = next(i) - self.assertEqual(t.i, 3) - self.assertEqual(t.s, 'baz') - self.assertRaises(StopIteration, next, i) - self.assertEqual(curs.rownumber, 3) - self.assertEqual(curs.rowcount, 3) + t = next(i) + self.assertEqual(t.i, 3) + self.assertEqual(t.s, 'baz') + self.assertRaises(StopIteration, next, i) + self.assertEqual(curs.rownumber, 3) + self.assertEqual(curs.rowcount, 3) def test_record_updated(self): - curs = self.conn.cursor() - curs.execute("select 1 as foo;") - r = curs.fetchone() - self.assertEqual(r.foo, 1) + with self.conn.cursor() as curs: + curs.execute("select 1 as foo;") + r = curs.fetchone() + self.assertEqual(r.foo, 1) - curs.execute("select 2 as bar;") - r = curs.fetchone() - self.assertEqual(r.bar, 2) - self.assertRaises(AttributeError, getattr, r, 'foo') + curs.execute("select 2 as bar;") + r = curs.fetchone() + self.assertEqual(r.bar, 2) + self.assertRaises(AttributeError, getattr, r, 'foo') def test_no_result_no_surprise(self): - curs = self.conn.cursor() - curs.execute("update nttest set s = s") - self.assertRaises(psycopg2.ProgrammingError, curs.fetchone) + with self.conn.cursor() as curs: + curs.execute("update nttest set s = s") + self.assertRaises(psycopg2.ProgrammingError, curs.fetchone) - curs.execute("update nttest set s = s") - self.assertRaises(psycopg2.ProgrammingError, curs.fetchall) + curs.execute("update nttest set s = s") + self.assertRaises(psycopg2.ProgrammingError, curs.fetchall) def test_bad_col_names(self): - curs = self.conn.cursor() - curs.execute('select 1 as "foo.bar_baz", 2 as "?column?", 3 as "3"') - rv = curs.fetchone() - self.assertEqual(rv.foo_bar_baz, 1) - self.assertEqual(rv.f_column_, 2) - self.assertEqual(rv.f3, 3) + with self.conn.cursor() as curs: + curs.execute('select 1 as "foo.bar_baz", 2 as "?column?", 3 as "3"') + rv = curs.fetchone() + self.assertEqual(rv.foo_bar_baz, 1) + self.assertEqual(rv.f_column_, 2) + self.assertEqual(rv.f3, 3) @skip_before_python(3) @skip_before_postgres(8) def test_nonascii_name(self): - curs = self.conn.cursor() - curs.execute('select 1 as \xe5h\xe9') - rv = curs.fetchone() - self.assertEqual(getattr(rv, '\xe5h\xe9'), 1) + with self.conn.cursor() as curs: + curs.execute('select 1 as \xe5h\xe9') + rv = curs.fetchone() + self.assertEqual(getattr(rv, '\xe5h\xe9'), 1) def test_minimal_generation(self): # Instrument the class to verify it gets called the minimum number of times. @@ -543,91 +567,92 @@ class NamedTupleCursorTest(ConnectingTestCase): NamedTupleCursor._make_nt = f_patched try: - curs = self.conn.cursor() - curs.execute("select * from nttest order by 1") - curs.fetchone() - curs.fetchone() - curs.fetchone() - self.assertEqual(1, calls[0]) + with self.conn.cursor() as curs: + curs.execute("select * from nttest order by 1") + curs.fetchone() + curs.fetchone() + curs.fetchone() + self.assertEqual(1, calls[0]) - curs.execute("select * from nttest order by 1") - curs.fetchone() - curs.fetchall() - self.assertEqual(2, calls[0]) + curs.execute("select * from nttest order by 1") + curs.fetchone() + curs.fetchall() + self.assertEqual(2, calls[0]) - curs.execute("select * from nttest order by 1") - curs.fetchone() - curs.fetchmany(1) - self.assertEqual(3, calls[0]) + curs.execute("select * from nttest order by 1") + curs.fetchone() + curs.fetchmany(1) + self.assertEqual(3, calls[0]) finally: NamedTupleCursor._make_nt = f_orig @skip_before_postgres(8, 0) def test_named(self): - curs = self.conn.cursor('tmp') - curs.execute("""select i from generate_series(0,9) i""") - recs = [] - recs.extend(curs.fetchmany(5)) - recs.append(curs.fetchone()) - recs.extend(curs.fetchall()) - self.assertEqual(list(range(10)), [t.i for t in recs]) + with self.conn.cursor('tmp') as curs: + curs.execute("""select i from generate_series(0,9) i""") + recs = [] + recs.extend(curs.fetchmany(5)) + recs.append(curs.fetchone()) + recs.extend(curs.fetchall()) + self.assertEqual(list(range(10)), [t.i for t in recs]) def test_named_fetchone(self): - curs = self.conn.cursor('tmp') - curs.execute("""select 42 as i""") - t = curs.fetchone() - self.assertEqual(t.i, 42) + with self.conn.cursor('tmp') as curs: + curs.execute("""select 42 as i""") + t = curs.fetchone() + self.assertEqual(t.i, 42) def test_named_fetchmany(self): - curs = self.conn.cursor('tmp') - curs.execute("""select 42 as i""") - recs = curs.fetchmany(10) - self.assertEqual(recs[0].i, 42) + with self.conn.cursor('tmp') as curs: + curs.execute("""select 42 as i""") + recs = curs.fetchmany(10) + self.assertEqual(recs[0].i, 42) def test_named_fetchall(self): - curs = self.conn.cursor('tmp') - curs.execute("""select 42 as i""") - recs = curs.fetchall() - self.assertEqual(recs[0].i, 42) + with self.conn.cursor('tmp') as curs: + curs.execute("""select 42 as i""") + recs = curs.fetchall() + self.assertEqual(recs[0].i, 42) @skip_before_postgres(8, 2) def test_not_greedy(self): - curs = self.conn.cursor('tmp') - curs.itersize = 2 - curs.execute("""select clock_timestamp() as ts from generate_series(1,3)""") - recs = [] - for t in curs: - time.sleep(0.01) - recs.append(t) + with self.conn.cursor('tmp') as curs: + curs.itersize = 2 + curs.execute( + """select clock_timestamp() as ts from generate_series(1,3)""") + recs = [] + for t in curs: + time.sleep(0.01) + recs.append(t) - # check that the dataset was not fetched in a single gulp - self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005)) - self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099)) + # check that the dataset was not fetched in a single gulp + self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005)) + self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099)) @skip_before_postgres(8, 0) def test_named_rownumber(self): - curs = self.conn.cursor('tmp') - # Only checking for dataset < itersize: - # see CursorTests.test_iter_named_cursor_rownumber - curs.itersize = 4 - curs.execute("""select * from generate_series(1,3)""") - for i, t in enumerate(curs): - self.assertEqual(i + 1, curs.rownumber) + with self.conn.cursor('tmp') as curs: + # Only checking for dataset < itersize: + # see CursorTests.test_iter_named_cursor_rownumber + curs.itersize = 4 + curs.execute("""select * from generate_series(1,3)""") + for i, t in enumerate(curs): + self.assertEqual(i + 1, curs.rownumber) def test_cache(self): NamedTupleCursor._cached_make_nt.cache_clear() - curs = self.conn.cursor() - curs.execute("select 10 as a, 20 as b") - r1 = curs.fetchone() - curs.execute("select 10 as a, 20 as c") - r2 = curs.fetchone() + with self.conn.cursor() as curs: + curs.execute("select 10 as a, 20 as b") + r1 = curs.fetchone() + curs.execute("select 10 as a, 20 as c") + r2 = curs.fetchone() # Get a new cursor to check that the cache works across multiple ones - curs = self.conn.cursor() - curs.execute("select 10 as a, 30 as b") - r3 = curs.fetchone() + with self.conn.cursor() as curs: + curs.execute("select 10 as a, 30 as b") + r3 = curs.fetchone() self.assert_(type(r1) is type(r3)) self.assert_(type(r1) is not type(r2)) @@ -643,20 +668,20 @@ class NamedTupleCursorTest(ConnectingTestCase): lru_cache(8)(NamedTupleCursor._cached_make_nt.__wrapped__) try: recs = [] - curs = self.conn.cursor() - for i in range(10): - curs.execute("select 1 as f%s" % i) - recs.append(curs.fetchone()) + with self.conn.cursor() as curs: + for i in range(10): + curs.execute("select 1 as f%s" % i) + recs.append(curs.fetchone()) - # Still in cache - curs.execute("select 1 as f9") - rec = curs.fetchone() - self.assert_(any(type(r) is type(rec) for r in recs)) + # Still in cache + curs.execute("select 1 as f9") + rec = curs.fetchone() + self.assert_(any(type(r) is type(rec) for r in recs)) - # Gone from cache - curs.execute("select 1 as f0") - rec = curs.fetchone() - self.assert_(all(type(r) is not type(rec) for r in recs)) + # Gone from cache + curs.execute("select 1 as f0") + rec = curs.fetchone() + self.assert_(all(type(r) is not type(rec) for r in recs)) finally: NamedTupleCursor._cached_make_nt = old_func diff --git a/tests/test_fast_executemany.py b/tests/test_fast_executemany.py index eaba029c..8746303b 100755 --- a/tests/test_fast_executemany.py +++ b/tests/test_fast_executemany.py @@ -46,219 +46,219 @@ class TestPaginate(unittest.TestCase): class FastExecuteTestMixin(object): def setUp(self): super(FastExecuteTestMixin, self).setUp() - cur = self.conn.cursor() - cur.execute("""create table testfast ( - id serial primary key, date date, val int, data text)""") + with self.conn.cursor() as cur: + cur.execute("""create table testfast ( + id serial primary key, date date, val int, data text)""") class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase): def test_empty(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, val) values (%s, %s)", - []) - cur.execute("select * from testfast order by id") - self.assertEqual(cur.fetchall(), []) + with self.conn.cursor() as cur: + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, val) values (%s, %s)", + []) + cur.execute("select * from testfast order by id") + self.assertEqual(cur.fetchall(), []) def test_one(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, val) values (%s, %s)", - iter([(1, 10)])) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(1, 10)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, val) values (%s, %s)", + iter([(1, 10)])) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(1, 10)]) def test_tuples(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, date, val) values (%s, %s, %s)", - ((i, date(2017, 1, i + 1), i * 10) for i in range(10))) - cur.execute("select id, date, val from testfast order by id") - self.assertEqual(cur.fetchall(), - [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, date, val) values (%s, %s, %s)", + ((i, date(2017, 1, i + 1), i * 10) for i in range(10))) + cur.execute("select id, date, val from testfast order by id") + self.assertEqual(cur.fetchall(), + [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) def test_many(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, val) values (%s, %s)", - ((i, i * 10) for i in range(1000))) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, val) values (%s, %s)", + ((i, i * 10) for i in range(1000))) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) def test_composed(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - sql.SQL("insert into {0} (id, val) values (%s, %s)") - .format(sql.Identifier('testfast')), - ((i, i * 10) for i in range(1000))) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_batch(cur, + sql.SQL("insert into {0} (id, val) values (%s, %s)") + .format(sql.Identifier('testfast')), + ((i, i * 10) for i in range(1000))) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) def test_pages(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, val) values (%s, %s)", - ((i, i * 10) for i in range(25)), - page_size=10) + with self.conn.cursor() as cur: + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, val) values (%s, %s)", + ((i, i * 10) for i in range(25)), + page_size=10) - # last command was 5 statements - self.assertEqual(sum(c == u';' for c in cur.query.decode('ascii')), 4) + # last command was 5 statements + self.assertEqual(sum(c == u';' for c in cur.query.decode('ascii')), 4) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) @testutils.skip_before_postgres(8, 0) def test_unicode(self): - cur = self.conn.cursor() - ext.register_type(ext.UNICODE, cur) - snowman = u"\u2603" + with self.conn.cursor() as cur: + ext.register_type(ext.UNICODE, cur) + snowman = u"\u2603" - # unicode in statement - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, - [(1, 'x')]) - cur.execute("select id, data from testfast where id = 1") - self.assertEqual(cur.fetchone(), (1, 'x')) + # unicode in statement + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, + [(1, 'x')]) + cur.execute("select id, data from testfast where id = 1") + self.assertEqual(cur.fetchone(), (1, 'x')) - # unicode in data - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, data) values (%s, %s)", - [(2, snowman)]) - cur.execute("select id, data from testfast where id = 2") - self.assertEqual(cur.fetchone(), (2, snowman)) + # unicode in data + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, data) values (%s, %s)", + [(2, snowman)]) + cur.execute("select id, data from testfast where id = 2") + self.assertEqual(cur.fetchone(), (2, snowman)) - # unicode in both - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, - [(3, snowman)]) - cur.execute("select id, data from testfast where id = 3") - self.assertEqual(cur.fetchone(), (3, snowman)) + # unicode in both + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, + [(3, snowman)]) + cur.execute("select id, data from testfast where id = 3") + self.assertEqual(cur.fetchone(), (3, snowman)) @testutils.skip_before_postgres(8, 2) class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase): def test_empty(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s", - []) - cur.execute("select * from testfast order by id") - self.assertEqual(cur.fetchall(), []) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s", + []) + cur.execute("select * from testfast order by id") + self.assertEqual(cur.fetchall(), []) def test_one(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s", - iter([(1, 10)])) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(1, 10)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s", + iter([(1, 10)])) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(1, 10)]) def test_tuples(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, date, val) values %s", - ((i, date(2017, 1, i + 1), i * 10) for i in range(10))) - cur.execute("select id, date, val from testfast order by id") - self.assertEqual(cur.fetchall(), - [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, date, val) values %s", + ((i, date(2017, 1, i + 1), i * 10) for i in range(10))) + cur.execute("select id, date, val from testfast order by id") + self.assertEqual(cur.fetchall(), + [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) def test_dicts(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, date, val) values %s", - (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar") - for i in range(10)), - template='(%(id)s, %(date)s, %(val)s)') - cur.execute("select id, date, val from testfast order by id") - self.assertEqual(cur.fetchall(), - [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, date, val) values %s", + (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar") + for i in range(10)), + template='(%(id)s, %(date)s, %(val)s)') + cur.execute("select id, date, val from testfast order by id") + self.assertEqual(cur.fetchall(), + [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) def test_many(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s", - ((i, i * 10) for i in range(1000))) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s", + ((i, i * 10) for i in range(1000))) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) def test_composed(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - sql.SQL("insert into {0} (id, val) values %s") - .format(sql.Identifier('testfast')), - ((i, i * 10) for i in range(1000))) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + sql.SQL("insert into {0} (id, val) values %s") + .format(sql.Identifier('testfast')), + ((i, i * 10) for i in range(1000))) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) def test_pages(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s", - ((i, i * 10) for i in range(25)), - page_size=10) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s", + ((i, i * 10) for i in range(25)), + page_size=10) - # last statement was 5 tuples (one parens is for the fields list) - self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6) + # last statement was 5 tuples (one parens is for the fields list) + self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) def test_unicode(self): - cur = self.conn.cursor() - ext.register_type(ext.UNICODE, cur) - snowman = u"\u2603" + with self.conn.cursor() as cur: + ext.register_type(ext.UNICODE, cur) + snowman = u"\u2603" - # unicode in statement - psycopg2.extras.execute_values(cur, - "insert into testfast (id, data) values %%s -- %s" % snowman, - [(1, 'x')]) - cur.execute("select id, data from testfast where id = 1") - self.assertEqual(cur.fetchone(), (1, 'x')) + # unicode in statement + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %%s -- %s" % snowman, + [(1, 'x')]) + cur.execute("select id, data from testfast where id = 1") + self.assertEqual(cur.fetchone(), (1, 'x')) - # unicode in data - psycopg2.extras.execute_values(cur, - "insert into testfast (id, data) values %s", - [(2, snowman)]) - cur.execute("select id, data from testfast where id = 2") - self.assertEqual(cur.fetchone(), (2, snowman)) + # unicode in data + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %s", + [(2, snowman)]) + cur.execute("select id, data from testfast where id = 2") + self.assertEqual(cur.fetchone(), (2, snowman)) - # unicode in both - psycopg2.extras.execute_values(cur, - "insert into testfast (id, data) values %%s -- %s" % snowman, - [(3, snowman)]) - cur.execute("select id, data from testfast where id = 3") - self.assertEqual(cur.fetchone(), (3, snowman)) + # unicode in both + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %%s -- %s" % snowman, + [(3, snowman)]) + cur.execute("select id, data from testfast where id = 3") + self.assertEqual(cur.fetchone(), (3, snowman)) def test_returning(self): - cur = self.conn.cursor() - result = psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s returning id", - ((i, i * 10) for i in range(25)), - page_size=10, fetch=True) - # result contains all returned pages - self.assertEqual([r[0] for r in result], list(range(25))) + with self.conn.cursor() as cur: + result = psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s returning id", + ((i, i * 10) for i in range(25)), + page_size=10, fetch=True) + # result contains all returned pages + self.assertEqual([r[0] for r in result], list(range(25))) def test_invalid_sql(self): - cur = self.conn.cursor() - self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, - "insert", []) - self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, - "insert %s and %s", []) - self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, - "insert %f", []) - self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, - "insert %f %s", []) + with self.conn.cursor() as cur: + self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, + "insert", []) + self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, + "insert %s and %s", []) + self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, + "insert %f", []) + self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, + "insert %f %s", []) def test_percent_escape(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, data) values %s -- a%%b", - [(1, 'hi')]) - self.assert_(b'a%%b' not in cur.query) - self.assert_(b'a%b' in cur.query) + with self.conn.cursor() as cur: + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %s -- a%%b", + [(1, 'hi')]) + self.assert_(b'a%%b' not in cur.query) + self.assert_(b'a%b' in cur.query) - cur.execute("select id, data from testfast") - self.assertEqual(cur.fetchall(), [(1, 'hi')]) + cur.execute("select id, data from testfast") + self.assertEqual(cur.fetchall(), [(1, 'hi')]) def test_suite(): diff --git a/tests/test_green.py b/tests/test_green.py index e56ce586..635a1907 100755 --- a/tests/test_green.py +++ b/tests/test_green.py @@ -71,14 +71,14 @@ class GreenTestCase(ConnectingTestCase): # a very large query requires a flush loop to be sent to the backend conn = self.conn stub = self.set_stub_wait_callback(conn) - curs = conn.cursor() - for mb in 1, 5, 10, 20, 50: - size = mb * 1024 * 1024 - del stub.polls[:] - curs.execute("select %s;", ('x' * size,)) - self.assertEqual(size, len(curs.fetchone()[0])) - if stub.polls.count(psycopg2.extensions.POLL_WRITE) > 1: - return + with conn.cursor() as curs: + for mb in 1, 5, 10, 20, 50: + size = mb * 1024 * 1024 + del stub.polls[:] + curs.execute("select %s;", ('x' * size,)) + self.assertEqual(size, len(curs.fetchone()[0])) + if stub.polls.count(psycopg2.extensions.POLL_WRITE) > 1: + return # This is more a testing glitch than an error: it happens # on high load on linux: probably because the kernel has more @@ -105,21 +105,21 @@ class GreenTestCase(ConnectingTestCase): # if there is an error in a green query, don't freak out and close # the connection conn = self.conn - curs = conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, - curs.execute, "select the unselectable") + with conn.cursor() as curs: + self.assertRaises(psycopg2.ProgrammingError, + curs.execute, "select the unselectable") - # check that the connection is left in an usable state - self.assert_(not conn.closed) - conn.rollback() - curs.execute("select 1") - self.assertEqual(curs.fetchone()[0], 1) + # check that the connection is left in an usable state + self.assert_(not conn.closed) + conn.rollback() + curs.execute("select 1") + self.assertEqual(curs.fetchone()[0], 1) @skip_before_postgres(8, 2) def test_copy_no_hang(self): - cur = self.conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, - cur.execute, "copy (select 1) to stdout") + with self.conn.cursor() as cur: + self.assertRaises(psycopg2.ProgrammingError, + cur.execute, "copy (select 1) to stdout") @slow @skip_before_postgres(9, 0) @@ -137,19 +137,19 @@ class GreenTestCase(ConnectingTestCase): raise conn.OperationalError("bad state from poll: %s" % state) stub = self.set_stub_wait_callback(self.conn, wait) - cur = self.conn.cursor() - cur.execute(""" - select 1; - do $$ - begin - raise notice 'hello'; - end - $$ language plpgsql; - select pg_sleep(1); - """) + with self.conn.cursor() as cur: + cur.execute(""" + select 1; + do $$ + begin + raise notice 'hello'; + end + $$ language plpgsql; + select pg_sleep(1); + """) - polls = stub.polls.count(POLL_READ) - self.assert_(polls > 8, polls) + polls = stub.polls.count(POLL_READ) + self.assert_(polls > 8, polls) class CallbackErrorTestCase(ConnectingTestCase): @@ -203,16 +203,16 @@ class CallbackErrorTestCase(ConnectingTestCase): for i in range(100): self.to_error = None cnn = self.connect() - cur = cnn.cursor() - self.to_error = i - try: - cur.execute("select 1") - cur.fetchone() - except ZeroDivisionError: - pass - else: - # The query completed - return + with cnn.cursor() as cur: + self.to_error = i + try: + cur.execute("select 1") + cur.fetchone() + except ZeroDivisionError: + pass + else: + # The query completed + return self.fail("you should have had a success or an error by now") @@ -220,16 +220,19 @@ class CallbackErrorTestCase(ConnectingTestCase): for i in range(100): self.to_error = None cnn = self.connect() - cur = cnn.cursor('foo') - self.to_error = i - try: - cur.execute("select 1") - cur.fetchone() - except ZeroDivisionError: - pass - else: - # The query completed - return + with cnn.cursor('foo') as cur: + self.to_error = i + try: + cur.execute("select 1") + cur.fetchone() + except ZeroDivisionError: + pass + else: + # The query completed + return + finally: + # Don't raise an exception in the cursor context manager. + self.to_error = None self.fail("you should have had a success or an error by now") diff --git a/tests/test_ipaddress.py b/tests/test_ipaddress.py index ccbae291..ff92a711 100755 --- a/tests/test_ipaddress.py +++ b/tests/test_ipaddress.py @@ -33,82 +33,82 @@ except ImportError: @unittest.skipIf(ip is None, "'ipaddress' module not available") class NetworkingTestCase(testutils.ConnectingTestCase): def test_inet_cast(self): - cur = self.conn.cursor() - psycopg2.extras.register_ipaddress(cur) + with self.conn.cursor() as cur: + psycopg2.extras.register_ipaddress(cur) - cur.execute("select null::inet") - self.assert_(cur.fetchone()[0] is None) + cur.execute("select null::inet") + self.assert_(cur.fetchone()[0] is None) - cur.execute("select '127.0.0.1/24'::inet") - obj = cur.fetchone()[0] - self.assert_(isinstance(obj, ip.IPv4Interface), repr(obj)) - self.assertEquals(obj, ip.ip_interface('127.0.0.1/24')) + cur.execute("select '127.0.0.1/24'::inet") + obj = cur.fetchone()[0] + self.assert_(isinstance(obj, ip.IPv4Interface), repr(obj)) + self.assertEquals(obj, ip.ip_interface('127.0.0.1/24')) - cur.execute("select '::ffff:102:300/128'::inet") - obj = cur.fetchone()[0] - self.assert_(isinstance(obj, ip.IPv6Interface), repr(obj)) - self.assertEquals(obj, ip.ip_interface('::ffff:102:300/128')) + cur.execute("select '::ffff:102:300/128'::inet") + obj = cur.fetchone()[0] + self.assert_(isinstance(obj, ip.IPv6Interface), repr(obj)) + self.assertEquals(obj, ip.ip_interface('::ffff:102:300/128')) @testutils.skip_before_postgres(8, 2) def test_inet_array_cast(self): - cur = self.conn.cursor() - psycopg2.extras.register_ipaddress(cur) - cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::inet[]") - l = cur.fetchone()[0] - self.assert_(l[0] is None) - self.assertEquals(l[1], ip.ip_interface('127.0.0.1')) - self.assertEquals(l[2], ip.ip_interface('::ffff:102:300/128')) - self.assert_(isinstance(l[1], ip.IPv4Interface), l) - self.assert_(isinstance(l[2], ip.IPv6Interface), l) + with self.conn.cursor() as cur: + psycopg2.extras.register_ipaddress(cur) + cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::inet[]") + l = cur.fetchone()[0] + self.assert_(l[0] is None) + self.assertEquals(l[1], ip.ip_interface('127.0.0.1')) + self.assertEquals(l[2], ip.ip_interface('::ffff:102:300/128')) + self.assert_(isinstance(l[1], ip.IPv4Interface), l) + self.assert_(isinstance(l[2], ip.IPv6Interface), l) def test_inet_adapt(self): - cur = self.conn.cursor() - psycopg2.extras.register_ipaddress(cur) + with self.conn.cursor() as cur: + psycopg2.extras.register_ipaddress(cur) - cur.execute("select %s", [ip.ip_interface('127.0.0.1/24')]) - self.assertEquals(cur.fetchone()[0], '127.0.0.1/24') + cur.execute("select %s", [ip.ip_interface('127.0.0.1/24')]) + self.assertEquals(cur.fetchone()[0], '127.0.0.1/24') - cur.execute("select %s", [ip.ip_interface('::ffff:102:300/128')]) - self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128') + cur.execute("select %s", [ip.ip_interface('::ffff:102:300/128')]) + self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128') def test_cidr_cast(self): - cur = self.conn.cursor() - psycopg2.extras.register_ipaddress(cur) + with self.conn.cursor() as cur: + psycopg2.extras.register_ipaddress(cur) - cur.execute("select null::cidr") - self.assert_(cur.fetchone()[0] is None) + cur.execute("select null::cidr") + self.assert_(cur.fetchone()[0] is None) - cur.execute("select '127.0.0.0/24'::cidr") - obj = cur.fetchone()[0] - self.assert_(isinstance(obj, ip.IPv4Network), repr(obj)) - self.assertEquals(obj, ip.ip_network('127.0.0.0/24')) + cur.execute("select '127.0.0.0/24'::cidr") + obj = cur.fetchone()[0] + self.assert_(isinstance(obj, ip.IPv4Network), repr(obj)) + self.assertEquals(obj, ip.ip_network('127.0.0.0/24')) - cur.execute("select '::ffff:102:300/128'::cidr") - obj = cur.fetchone()[0] - self.assert_(isinstance(obj, ip.IPv6Network), repr(obj)) - self.assertEquals(obj, ip.ip_network('::ffff:102:300/128')) + cur.execute("select '::ffff:102:300/128'::cidr") + obj = cur.fetchone()[0] + self.assert_(isinstance(obj, ip.IPv6Network), repr(obj)) + self.assertEquals(obj, ip.ip_network('::ffff:102:300/128')) @testutils.skip_before_postgres(8, 2) def test_cidr_array_cast(self): - cur = self.conn.cursor() - psycopg2.extras.register_ipaddress(cur) - cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::cidr[]") - l = cur.fetchone()[0] - self.assert_(l[0] is None) - self.assertEquals(l[1], ip.ip_network('127.0.0.1')) - self.assertEquals(l[2], ip.ip_network('::ffff:102:300/128')) - self.assert_(isinstance(l[1], ip.IPv4Network), l) - self.assert_(isinstance(l[2], ip.IPv6Network), l) + with self.conn.cursor() as cur: + psycopg2.extras.register_ipaddress(cur) + cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::cidr[]") + l = cur.fetchone()[0] + self.assert_(l[0] is None) + self.assertEquals(l[1], ip.ip_network('127.0.0.1')) + self.assertEquals(l[2], ip.ip_network('::ffff:102:300/128')) + self.assert_(isinstance(l[1], ip.IPv4Network), l) + self.assert_(isinstance(l[2], ip.IPv6Network), l) def test_cidr_adapt(self): - cur = self.conn.cursor() - psycopg2.extras.register_ipaddress(cur) + with self.conn.cursor() as cur: + psycopg2.extras.register_ipaddress(cur) - cur.execute("select %s", [ip.ip_network('127.0.0.0/24')]) - self.assertEquals(cur.fetchone()[0], '127.0.0.0/24') + cur.execute("select %s", [ip.ip_network('127.0.0.0/24')]) + self.assertEquals(cur.fetchone()[0], '127.0.0.0/24') - cur.execute("select %s", [ip.ip_network('::ffff:102:300/128')]) - self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128') + cur.execute("select %s", [ip.ip_network('::ffff:102:300/128')]) + self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128') def test_suite(): diff --git a/tests/test_module.py b/tests/test_module.py index 416e6237..85153fdf 100755 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -154,22 +154,22 @@ class ConnectTestCase(unittest.TestCase): class ExceptionsTestCase(ConnectingTestCase): def test_attributes(self): - cur = self.conn.cursor() - try: - cur.execute("select * from nonexist") - except psycopg2.Error as exc: - e = exc + with self.conn.cursor() as cur: + try: + cur.execute("select * from nonexist") + except psycopg2.Error as exc: + e = exc self.assertEqual(e.pgcode, '42P01') self.assert_(e.pgerror) self.assert_(e.cursor is cur) def test_diagnostics_attributes(self): - cur = self.conn.cursor() - try: - cur.execute("select * from nonexist") - except psycopg2.Error as exc: - e = exc + with self.conn.cursor() as cur: + try: + cur.execute("select * from nonexist") + except psycopg2.Error as exc: + e = exc diag = e.diag self.assert_(isinstance(diag, psycopg2.extensions.Diagnostics)) @@ -195,11 +195,11 @@ class ExceptionsTestCase(ConnectingTestCase): def test_diagnostics_life(self): def tmp(): - cur = self.conn.cursor() - try: - cur.execute("select * from nonexist") - except psycopg2.Error as exc: - return cur, exc + with self.conn.cursor() as cur: + try: + cur.execute("select * from nonexist") + except psycopg2.Error as exc: + return cur, exc cur, e = tmp() diag = e.diag diff --git a/tests/test_notify.py b/tests/test_notify.py index 51865e64..ce3e0b03 100755 --- a/tests/test_notify.py +++ b/tests/test_notify.py @@ -120,10 +120,11 @@ conn.close() self.listen('foo') pid = int(self.notify('foo').communicate()[0]) self.assertEqual(0, len(self.conn.notifies)) - self.conn.cursor().execute('select 1;') - self.assertEqual(1, len(self.conn.notifies)) - self.assertEqual(pid, self.conn.notifies[0][0]) - self.assertEqual('foo', self.conn.notifies[0][1]) + with self.conn.cursor() as cur: + cur.execute('select 1;') + self.assertEqual(1, len(self.conn.notifies)) + self.assertEqual(pid, self.conn.notifies[0][0]) + self.assertEqual('foo', self.conn.notifies[0][1]) @slow def test_notify_object(self): diff --git a/tests/test_quote.py b/tests/test_quote.py index 42e90eb7..7d3759a8 100755 --- a/tests/test_quote.py +++ b/tests/test_quote.py @@ -56,24 +56,24 @@ class QuotingTestCase(ConnectingTestCase): """ data += "".join(map(chr, range(1, 127))) - curs = self.conn.cursor() - curs.execute("SELECT %s;", (data,)) - res = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute("SELECT %s;", (data,)) + res = curs.fetchone()[0] self.assertEqual(res, data) self.assert_(not self.conn.notices) def test_string_null_terminator(self): - curs = self.conn.cursor() - data = 'abcd\x01\x00cdefg' + with self.conn.cursor() as curs: + data = 'abcd\x01\x00cdefg' - try: - curs.execute("SELECT %s", (data,)) - except ValueError as e: - self.assertEquals(str(e), - 'A string literal cannot contain NUL (0x00) characters.') - else: - self.fail("ValueError not raised") + try: + curs.execute("SELECT %s", (data,)) + except ValueError as e: + self.assertEquals(str(e), + 'A string literal cannot contain NUL (0x00) characters.') + else: + self.fail("ValueError not raised") def test_binary(self): data = b"""some data with \000\013 binary @@ -84,12 +84,12 @@ class QuotingTestCase(ConnectingTestCase): else: data += bytes(list(range(256))) - curs = self.conn.cursor() - curs.execute("SELECT %s::bytea;", (psycopg2.Binary(data),)) - if PY2: - res = str(curs.fetchone()[0]) - else: - res = curs.fetchone()[0].tobytes() + with self.conn.cursor() as curs: + curs.execute("SELECT %s::bytea;", (psycopg2.Binary(data),)) + if PY2: + res = str(curs.fetchone()[0]) + else: + res = curs.fetchone()[0].tobytes() if res[0] in (b'x', ord(b'x')) and self.conn.info.server_version >= 90000: return self.skipTest( @@ -99,86 +99,87 @@ class QuotingTestCase(ConnectingTestCase): self.assert_(not self.conn.notices) def test_unicode(self): - curs = self.conn.cursor() - curs.execute("SHOW server_encoding") - server_encoding = curs.fetchone()[0] - if server_encoding != "UTF8": - return self.skipTest( - "Unicode test skipped since server encoding is %s" - % server_encoding) + with self.conn.cursor() as curs: + curs.execute("SHOW server_encoding") + server_encoding = curs.fetchone()[0] + if server_encoding != "UTF8": + return self.skipTest( + "Unicode test skipped since server encoding is %s" + % server_encoding) - data = u"""some data with \t chars - to escape into, 'quotes', \u20ac euro sign and \\ a backslash too. - """ - data += u"".join(map(unichr, [u for u in range(1, 65536) - if not 0xD800 <= u <= 0xDFFF])) # surrogate area - self.conn.set_client_encoding('UNICODE') + data = u"""some data with \t chars + to escape into, 'quotes', \u20ac euro sign and \\ a backslash too. + """ + data += u"".join(map(unichr, [u for u in range(1, 65536) + if not 0xD800 <= u <= 0xDFFF])) # surrogate area + self.conn.set_client_encoding('UNICODE') - psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn) - curs.execute("SELECT %s::text;", (data,)) - res = curs.fetchone()[0] + psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn) + curs.execute("SELECT %s::text;", (data,)) + res = curs.fetchone()[0] - self.assertEqual(res, data) - self.assert_(not self.conn.notices) + self.assertEqual(res, data) + self.assert_(not self.conn.notices) def test_latin1(self): self.conn.set_client_encoding('LATIN1') - curs = self.conn.cursor() - if PY2: - data = ''.join(map(chr, range(32, 127) + range(160, 256))) - else: - data = bytes(list(range(32, 127)) - + list(range(160, 256))).decode('latin1') - - # as string - curs.execute("SELECT %s::text;", (data,)) - res = curs.fetchone()[0] - self.assertEqual(res, data) - self.assert_(not self.conn.notices) - - # as unicode - if PY2: - psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn) - data = data.decode('latin1') + with self.conn.cursor() as curs: + if PY2: + data = ''.join(map(chr, range(32, 127) + range(160, 256))) + else: + data = bytes(list(range(32, 127)) + + list(range(160, 256))).decode('latin1') + # as string curs.execute("SELECT %s::text;", (data,)) res = curs.fetchone()[0] self.assertEqual(res, data) self.assert_(not self.conn.notices) + # as unicode + if PY2: + psycopg2.extensions.register_type( + psycopg2.extensions.UNICODE, self.conn) + data = data.decode('latin1') + + curs.execute("SELECT %s::text;", (data,)) + res = curs.fetchone()[0] + self.assertEqual(res, data) + self.assert_(not self.conn.notices) + def test_koi8(self): self.conn.set_client_encoding('KOI8') - curs = self.conn.cursor() - if PY2: - data = ''.join(map(chr, range(32, 127) + range(128, 256))) - else: - data = bytes(list(range(32, 127)) - + list(range(128, 256))).decode('koi8_r') - - # as string - curs.execute("SELECT %s::text;", (data,)) - res = curs.fetchone()[0] - self.assertEqual(res, data) - self.assert_(not self.conn.notices) - - # as unicode - if PY2: - psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn) - data = data.decode('koi8_r') + with self.conn.cursor() as curs: + if PY2: + data = ''.join(map(chr, range(32, 127) + range(128, 256))) + else: + data = bytes(list(range(32, 127)) + + list(range(128, 256))).decode('koi8_r') + # as string curs.execute("SELECT %s::text;", (data,)) res = curs.fetchone()[0] self.assertEqual(res, data) self.assert_(not self.conn.notices) + # as unicode + if PY2: + psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn) + data = data.decode('koi8_r') + + curs.execute("SELECT %s::text;", (data,)) + res = curs.fetchone()[0] + self.assertEqual(res, data) + self.assert_(not self.conn.notices) + def test_bytes(self): snowman = u"\u2603" conn = self.connect() conn.set_client_encoding('UNICODE') psycopg2.extensions.register_type(psycopg2.extensions.BYTES, conn) - curs = conn.cursor() - curs.execute("select %s::text", (snowman,)) - x = curs.fetchone()[0] + with conn.cursor() as curs: + curs.execute("select %s::text", (snowman,)) + x = curs.fetchone()[0] self.assert_(isinstance(x, bytes)) self.assertEqual(x, snowman.encode('utf8')) diff --git a/tests/test_replication.py b/tests/test_replication.py index 3ed68a57..e5564eb2 100755 --- a/tests/test_replication.py +++ b/tests/test_replication.py @@ -73,14 +73,13 @@ class ReplicationTestCase(ConnectingTestCase): conn = self.connect() if conn is None: return - cur = conn.cursor() - - try: - cur.execute("DROP TABLE dummy1") - except psycopg2.ProgrammingError: - conn.rollback() - cur.execute( - "CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id") + with conn.cursor() as cur: + try: + cur.execute("DROP TABLE dummy1") + except psycopg2.ProgrammingError: + conn.rollback() + cur.execute( + "CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id") conn.commit() @@ -90,9 +89,9 @@ class ReplicationTest(ReplicationTestCase): conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) if conn is None: return - cur = conn.cursor() - cur.execute("IDENTIFY_SYSTEM") - cur.fetchall() + with conn.cursor() as cur: + cur.execute("IDENTIFY_SYSTEM") + cur.fetchall() @skip_before_postgres(9, 0) def test_datestyle(self): @@ -104,29 +103,28 @@ class ReplicationTest(ReplicationTestCase): connection_factory=PhysicalReplicationConnection) if conn is None: return - cur = conn.cursor() - cur.execute("IDENTIFY_SYSTEM") - cur.fetchall() + with conn.cursor() as cur: + cur.execute("IDENTIFY_SYSTEM") + cur.fetchall() @skip_before_postgres(9, 4) def test_logical_replication_connection(self): conn = self.repl_connect(connection_factory=LogicalReplicationConnection) if conn is None: return - cur = conn.cursor() - cur.execute("IDENTIFY_SYSTEM") - cur.fetchall() + with conn.cursor() as cur: + cur.execute("IDENTIFY_SYSTEM") + cur.fetchall() @skip_before_postgres(9, 4) # slots require 9.4 def test_create_replication_slot(self): conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) if conn is None: return - cur = conn.cursor() - - self.create_replication_slot(cur) - self.assertRaises( - psycopg2.ProgrammingError, self.create_replication_slot, cur) + with conn.cursor() as cur: + self.create_replication_slot(cur) + self.assertRaises( + psycopg2.ProgrammingError, self.create_replication_slot, cur) @skip_before_postgres(9, 4) # slots require 9.4 @skip_repl_if_green @@ -134,13 +132,12 @@ class ReplicationTest(ReplicationTestCase): conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) if conn is None: return - cur = conn.cursor() + with conn.cursor() as cur: + self.assertRaises(psycopg2.ProgrammingError, + cur.start_replication, self.slot) - self.assertRaises(psycopg2.ProgrammingError, - cur.start_replication, self.slot) - - self.create_replication_slot(cur) - cur.start_replication(self.slot) + self.create_replication_slot(cur) + cur.start_replication(self.slot) @skip_before_postgres(9, 4) # slots require 9.4 @skip_repl_if_green @@ -148,12 +145,11 @@ class ReplicationTest(ReplicationTestCase): conn = self.repl_connect(connection_factory=LogicalReplicationConnection) if conn is None: return - cur = conn.cursor() - - self.create_replication_slot(cur, output_plugin='test_decoding') - cur.start_replication_expert( - sql.SQL("START_REPLICATION SLOT {slot} LOGICAL 0/00000000").format( - slot=sql.Identifier(self.slot))) + with conn.cursor() as cur: + self.create_replication_slot(cur, output_plugin='test_decoding') + cur.start_replication_expert( + sql.SQL("START_REPLICATION SLOT {slot} LOGICAL 0/00000000").format( + slot=sql.Identifier(self.slot))) @skip_before_postgres(9, 4) # slots require 9.4 @skip_repl_if_green @@ -161,23 +157,22 @@ class ReplicationTest(ReplicationTestCase): conn = self.repl_connect(connection_factory=LogicalReplicationConnection) if conn is None: return - cur = conn.cursor() + with conn.cursor() as cur: + self.create_replication_slot(cur, output_plugin='test_decoding') + self.make_replication_events() - self.create_replication_slot(cur, output_plugin='test_decoding') - self.make_replication_events() + def consume(msg): + raise StopReplication() - def consume(msg): - raise StopReplication() + with self.assertRaises(psycopg2.DataError): + # try with invalid options + cur.start_replication( + slot_name=self.slot, options={'invalid_param': 'value'}) + cur.consume_stream(consume) - with self.assertRaises(psycopg2.DataError): - # try with invalid options - cur.start_replication( - slot_name=self.slot, options={'invalid_param': 'value'}) - cur.consume_stream(consume) - - # try with correct command - cur.start_replication(slot_name=self.slot) - self.assertRaises(StopReplication, cur.consume_stream, consume) + # try with correct command + cur.start_replication(slot_name=self.slot) + self.assertRaises(StopReplication, cur.consume_stream, consume) @skip_before_postgres(9, 4) # slots require 9.4 @skip_repl_if_green @@ -208,17 +203,16 @@ class ReplicationTest(ReplicationTestCase): conn = self.repl_connect(connection_factory=LogicalReplicationConnection) if conn is None: return - cur = conn.cursor() + with conn.cursor() as cur: + self.create_replication_slot(cur, output_plugin='test_decoding') - self.create_replication_slot(cur, output_plugin='test_decoding') + self.make_replication_events() - self.make_replication_events() + cur.start_replication(self.slot) - cur.start_replication(self.slot) - - def consume(msg): - raise StopReplication() - self.assertRaises(StopReplication, cur.consume_stream, consume) + def consume(msg): + raise StopReplication() + self.assertRaises(StopReplication, cur.consume_stream, consume) class AsyncReplicationTest(ReplicationTestCase): @@ -230,42 +224,41 @@ class AsyncReplicationTest(ReplicationTestCase): if conn is None: return - cur = conn.cursor() + with conn.cursor() as cur: + self.create_replication_slot(cur, output_plugin='test_decoding') + self.wait(cur) - self.create_replication_slot(cur, output_plugin='test_decoding') - self.wait(cur) + cur.start_replication(self.slot) + self.wait(cur) - cur.start_replication(self.slot) - self.wait(cur) + self.make_replication_events() - self.make_replication_events() + self.msg_count = 0 - self.msg_count = 0 + def consume(msg): + # just check the methods + "%s: %s" % (cur.io_timestamp, repr(msg)) + "%s: %s" % (cur.feedback_timestamp, repr(msg)) + "%s: %s" % (cur.wal_end, repr(msg)) - def consume(msg): - # just check the methods - "%s: %s" % (cur.io_timestamp, repr(msg)) - "%s: %s" % (cur.feedback_timestamp, repr(msg)) - "%s: %s" % (cur.wal_end, repr(msg)) + self.msg_count += 1 + if self.msg_count > 3: + cur.send_feedback(reply=True) + raise StopReplication() - self.msg_count += 1 - if self.msg_count > 3: - cur.send_feedback(reply=True) - raise StopReplication() + cur.send_feedback(flush_lsn=msg.data_start) - cur.send_feedback(flush_lsn=msg.data_start) + # cannot be used in asynchronous mode + self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume) - # cannot be used in asynchronous mode - self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume) - - def process_stream(): - while True: - msg = cur.read_message() - if msg: - consume(msg) - else: - select([cur], [], []) - self.assertRaises(StopReplication, process_stream) + def process_stream(): + while True: + msg = cur.read_message() + if msg: + consume(msg) + else: + select([cur], [], []) + self.assertRaises(StopReplication, process_stream) def test_suite(): diff --git a/tests/test_sql.py b/tests/test_sql.py index 9089ae77..53ab952d 100755 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -117,61 +117,61 @@ class SqlFormatTests(ConnectingTestCase): sql.SQL("select {0};").format(sql.Literal(Foo())).as_string, self.conn) def test_execute(self): - cur = self.conn.cursor() - cur.execute(""" - create table test_compose ( - id serial primary key, - foo text, bar text, "ba'z" text) - """) - cur.execute( - sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( - sql.Identifier('test_compose'), - sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), - (sql.Placeholder() * 3).join(', ')), - (10, 'a', 'b', 'c')) + with self.conn.cursor() as cur: + cur.execute(""" + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """) + cur.execute( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier('test_compose'), + sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), + (sql.Placeholder() * 3).join(', ')), + (10, 'a', 'b', 'c')) - cur.execute("select * from test_compose") - self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')]) + cur.execute("select * from test_compose") + self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')]) def test_executemany(self): - cur = self.conn.cursor() - cur.execute(""" - create table test_compose ( - id serial primary key, - foo text, bar text, "ba'z" text) - """) - cur.executemany( - sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( - sql.Identifier('test_compose'), - sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), - (sql.Placeholder() * 3).join(', ')), - [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) + with self.conn.cursor() as cur: + cur.execute(""" + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """) + cur.executemany( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier('test_compose'), + sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), + (sql.Placeholder() * 3).join(', ')), + [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) - cur.execute("select * from test_compose") - self.assertEqual(cur.fetchall(), - [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) + cur.execute("select * from test_compose") + self.assertEqual(cur.fetchall(), + [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) @skip_copy_if_green @skip_before_postgres(8, 2) def test_copy(self): - cur = self.conn.cursor() - cur.execute(""" - create table test_compose ( - id serial primary key, - foo text, bar text, "ba'z" text) - """) + with self.conn.cursor() as cur: + cur.execute(""" + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """) - s = StringIO("10\ta\tb\tc\n20\td\te\tf\n") - cur.copy_expert( - sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format( - t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s) + s = StringIO("10\ta\tb\tc\n20\td\te\tf\n") + cur.copy_expert( + sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format( + t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s) - s1 = StringIO() - cur.copy_expert( - sql.SQL("copy (select {f} from {t} order by id) to stdout").format( - t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s1) - s1.seek(0) - self.assertEqual(s1.read(), 'c\nf\n') + s1 = StringIO() + cur.copy_expert( + sql.SQL("copy (select {f} from {t} order by id) to stdout").format( + t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s1) + s1.seek(0) + self.assertEqual(s1.read(), 'c\nf\n') class IdentifierTests(ConnectingTestCase): diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 25feba0b..ef4cb6bc 100755 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -37,59 +37,59 @@ class TransactionTests(ConnectingTestCase): def setUp(self): ConnectingTestCase.setUp(self) self.conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE) - curs = self.conn.cursor() - curs.execute(''' - CREATE TEMPORARY TABLE table1 ( - id int PRIMARY KEY - )''') - # The constraint is set to deferrable for the commit_failed test - curs.execute(''' - CREATE TEMPORARY TABLE table2 ( - id int PRIMARY KEY, - table1_id int, - CONSTRAINT table2__table1_id__fk - FOREIGN KEY (table1_id) REFERENCES table1(id) DEFERRABLE)''') - curs.execute('INSERT INTO table1 VALUES (1)') - curs.execute('INSERT INTO table2 VALUES (1, 1)') + with self.conn.cursor() as curs: + curs.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') + # The constraint is set to deferrable for the commit_failed test + curs.execute(''' + CREATE TEMPORARY TABLE table2 ( + id int PRIMARY KEY, + table1_id int, + CONSTRAINT table2__table1_id__fk + FOREIGN KEY (table1_id) REFERENCES table1(id) DEFERRABLE)''') + curs.execute('INSERT INTO table1 VALUES (1)') + curs.execute('INSERT INTO table2 VALUES (1, 1)') self.conn.commit() def test_rollback(self): # Test that rollback undoes changes - curs = self.conn.cursor() - curs.execute('INSERT INTO table2 VALUES (2, 1)') - # Rollback takes us from BEGIN state to READY state - self.assertEqual(self.conn.status, STATUS_BEGIN) - self.conn.rollback() - self.assertEqual(self.conn.status, STATUS_READY) - curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2') - self.assertEqual(curs.fetchall(), []) + with self.conn.cursor() as curs: + curs.execute('INSERT INTO table2 VALUES (2, 1)') + # Rollback takes us from BEGIN state to READY state + self.assertEqual(self.conn.status, STATUS_BEGIN) + self.conn.rollback() + self.assertEqual(self.conn.status, STATUS_READY) + curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2') + self.assertEqual(curs.fetchall(), []) def test_commit(self): # Test that commit stores changes - curs = self.conn.cursor() - curs.execute('INSERT INTO table2 VALUES (2, 1)') - # Rollback takes us from BEGIN state to READY state - self.assertEqual(self.conn.status, STATUS_BEGIN) - self.conn.commit() - self.assertEqual(self.conn.status, STATUS_READY) - # Now rollback and show that the new record is still there: - self.conn.rollback() - curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2') - self.assertEqual(curs.fetchall(), [(2, 1)]) + with self.conn.cursor() as curs: + curs.execute('INSERT INTO table2 VALUES (2, 1)') + # Rollback takes us from BEGIN state to READY state + self.assertEqual(self.conn.status, STATUS_BEGIN) + self.conn.commit() + self.assertEqual(self.conn.status, STATUS_READY) + # Now rollback and show that the new record is still there: + self.conn.rollback() + curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2') + self.assertEqual(curs.fetchall(), [(2, 1)]) def test_failed_commit(self): # Test that we can recover from a failed commit. # We use a deferred constraint to cause a failure on commit. - curs = self.conn.cursor() - curs.execute('SET CONSTRAINTS table2__table1_id__fk DEFERRED') - curs.execute('INSERT INTO table2 VALUES (2, 42)') - # The commit should fail, and move the cursor back to READY state - self.assertEqual(self.conn.status, STATUS_BEGIN) - self.assertRaises(psycopg2.IntegrityError, self.conn.commit) - self.assertEqual(self.conn.status, STATUS_READY) - # The connection should be ready to use for the next transaction: - curs.execute('SELECT 1') - self.assertEqual(curs.fetchone()[0], 1) + with self.conn.cursor() as curs: + curs.execute('SET CONSTRAINTS table2__table1_id__fk DEFERRED') + curs.execute('INSERT INTO table2 VALUES (2, 42)') + # The commit should fail, and move the cursor back to READY state + self.assertEqual(self.conn.status, STATUS_BEGIN) + self.assertRaises(psycopg2.IntegrityError, self.conn.commit) + self.assertEqual(self.conn.status, STATUS_READY) + # The connection should be ready to use for the next transaction: + curs.execute('SELECT 1') + self.assertEqual(curs.fetchone()[0], 1) class DeadlockSerializationTests(ConnectingTestCase): @@ -103,32 +103,32 @@ class DeadlockSerializationTests(ConnectingTestCase): def setUp(self): ConnectingTestCase.setUp(self) - curs = self.conn.cursor() - # Drop table if it already exists - try: - curs.execute("DROP TABLE table1") - self.conn.commit() - except psycopg2.DatabaseError: - self.conn.rollback() - try: - curs.execute("DROP TABLE table2") - self.conn.commit() - except psycopg2.DatabaseError: - self.conn.rollback() - # Create sample data - curs.execute(""" - CREATE TABLE table1 ( - id int PRIMARY KEY, - name text) - """) - curs.execute("INSERT INTO table1 VALUES (1, 'hello')") - curs.execute("CREATE TABLE table2 (id int PRIMARY KEY)") + with self.conn.cursor() as curs: + # Drop table if it already exists + try: + curs.execute("DROP TABLE table1") + self.conn.commit() + except psycopg2.DatabaseError: + self.conn.rollback() + try: + curs.execute("DROP TABLE table2") + self.conn.commit() + except psycopg2.DatabaseError: + self.conn.rollback() + # Create sample data + curs.execute(""" + CREATE TABLE table1 ( + id int PRIMARY KEY, + name text) + """) + curs.execute("INSERT INTO table1 VALUES (1, 'hello')") + curs.execute("CREATE TABLE table2 (id int PRIMARY KEY)") self.conn.commit() def tearDown(self): - curs = self.conn.cursor() - curs.execute("DROP TABLE table1") - curs.execute("DROP TABLE table2") + with self.conn.cursor() as curs: + curs.execute("DROP TABLE table1") + curs.execute("DROP TABLE table2") self.conn.commit() ConnectingTestCase.tearDown(self) @@ -142,11 +142,11 @@ class DeadlockSerializationTests(ConnectingTestCase): def task1(): try: conn = self.connect() - curs = conn.cursor() - curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE") - step1.set() - step2.wait() - curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE") + with conn.cursor() as curs: + curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE") + step1.set() + step2.wait() + curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE") except psycopg2.DatabaseError as exc: self.thread1_error = exc step1.set() @@ -155,11 +155,11 @@ class DeadlockSerializationTests(ConnectingTestCase): def task2(): try: conn = self.connect() - curs = conn.cursor() - step1.wait() - curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE") - step2.set() - curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE") + with conn.cursor() as curs: + step1.wait() + curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE") + step2.set() + curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE") except psycopg2.DatabaseError as exc: self.thread2_error = exc step2.set() @@ -190,12 +190,12 @@ class DeadlockSerializationTests(ConnectingTestCase): def task1(): try: conn = self.connect() - curs = conn.cursor() - curs.execute("SELECT name FROM table1 WHERE id = 1") - curs.fetchall() - step1.set() - step2.wait() - curs.execute("UPDATE table1 SET name='task1' WHERE id = 1") + with conn.cursor() as curs: + curs.execute("SELECT name FROM table1 WHERE id = 1") + curs.fetchall() + step1.set() + step2.wait() + curs.execute("UPDATE table1 SET name='task1' WHERE id = 1") conn.commit() except psycopg2.DatabaseError as exc: self.thread1_error = exc @@ -205,9 +205,9 @@ class DeadlockSerializationTests(ConnectingTestCase): def task2(): try: conn = self.connect() - curs = conn.cursor() - step1.wait() - curs.execute("UPDATE table1 SET name='task2' WHERE id = 1") + with conn.cursor() as curs: + step1.wait() + curs.execute("UPDATE table1 SET name='task2' WHERE id = 1") conn.commit() except psycopg2.DatabaseError as exc: self.thread2_error = exc @@ -240,11 +240,11 @@ class QueryCancellationTests(ConnectingTestCase): @skip_before_postgres(8, 2) def test_statement_timeout(self): - curs = self.conn.cursor() - # Set a low statement timeout, then sleep for a longer period. - curs.execute('SET statement_timeout TO 10') - self.assertRaises(psycopg2.extensions.QueryCanceledError, - curs.execute, 'SELECT pg_sleep(50)') + with self.conn.cursor() as curs: + # Set a low statement timeout, then sleep for a longer period. + curs.execute('SET statement_timeout TO 10') + self.assertRaises(psycopg2.extensions.QueryCanceledError, + curs.execute, 'SELECT pg_sleep(50)') def test_suite(): diff --git a/tests/test_types_basic.py b/tests/test_types_basic.py index a6b7af03..84d22597 100755 --- a/tests/test_types_basic.py +++ b/tests/test_types_basic.py @@ -41,9 +41,9 @@ class TypesBasicTests(ConnectingTestCase): """Test that all type conversions are working.""" def execute(self, *args): - curs = self.conn.cursor() - curs.execute(*args) - return curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute(*args) + return curs.fetchone()[0] def testQuoting(self): s = "Quote'this\\! ''ok?''" @@ -156,26 +156,27 @@ class TypesBasicTests(ConnectingTestCase): def testEmptyArrayRegression(self): # ticket #42 - curs = self.conn.cursor() - curs.execute( - "create table array_test " - "(id integer, col timestamp without time zone[])") + with self.conn.cursor() as curs: + curs.execute( + "create table array_test " + "(id integer, col timestamp without time zone[])") - curs.execute("insert into array_test values (%s, %s)", - (1, [datetime.date(2011, 2, 14)])) - curs.execute("select col from array_test where id = 1") - self.assertEqual(curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)]) + curs.execute("insert into array_test values (%s, %s)", + (1, [datetime.date(2011, 2, 14)])) + curs.execute("select col from array_test where id = 1") + self.assertEqual( + curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)]) - curs.execute("insert into array_test values (%s, %s)", (2, [])) - curs.execute("select col from array_test where id = 2") - self.assertEqual(curs.fetchone()[0], []) + curs.execute("insert into array_test values (%s, %s)", (2, [])) + curs.execute("select col from array_test where id = 2") + self.assertEqual(curs.fetchone()[0], []) @testutils.skip_before_postgres(8, 4) def testNestedEmptyArray(self): # issue #788 - curs = self.conn.cursor() - curs.execute("select 10 = any(%s::int[])", ([[]], )) - self.assertFalse(curs.fetchone()[0]) + with self.conn.cursor() as curs: + curs.execute("select 10 = any(%s::int[])", ([[]], )) + self.assertFalse(curs.fetchone()[0]) def testEmptyArrayNoCast(self): s = self.execute("SELECT '{}' AS foo") @@ -204,86 +205,86 @@ class TypesBasicTests(ConnectingTestCase): self.failUnlessEqual(ss, r) def testArrayMalformed(self): - curs = self.conn.cursor() - ss = ['', '{', '{}}', '{' * 20 + '}' * 20] - for s in ss: - self.assertRaises(psycopg2.DataError, - psycopg2.extensions.STRINGARRAY, s.encode('utf8'), curs) + with self.conn.cursor() as curs: + ss = ['', '{', '{}}', '{' * 20 + '}' * 20] + for s in ss: + self.assertRaises(psycopg2.DataError, + psycopg2.extensions.STRINGARRAY, s.encode('utf8'), curs) def testTextArray(self): - curs = self.conn.cursor() - curs.execute("select '{a,b,c}'::text[]") - x = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute("select '{a,b,c}'::text[]") + x = curs.fetchone()[0] self.assert_(isinstance(x[0], str)) self.assertEqual(x, ['a', 'b', 'c']) def testUnicodeArray(self): psycopg2.extensions.register_type( psycopg2.extensions.UNICODEARRAY, self.conn) - curs = self.conn.cursor() - curs.execute("select '{a,b,c}'::text[]") - x = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute("select '{a,b,c}'::text[]") + x = curs.fetchone()[0] self.assert_(isinstance(x[0], text_type)) self.assertEqual(x, [u'a', u'b', u'c']) def testBytesArray(self): psycopg2.extensions.register_type( psycopg2.extensions.BYTESARRAY, self.conn) - curs = self.conn.cursor() - curs.execute("select '{a,b,c}'::text[]") - x = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute("select '{a,b,c}'::text[]") + x = curs.fetchone()[0] self.assert_(isinstance(x[0], bytes)) self.assertEqual(x, [b'a', b'b', b'c']) @testutils.skip_before_postgres(8, 2) def testArrayOfNulls(self): - curs = self.conn.cursor() - curs.execute(""" - create table na ( - texta text[], - inta int[], - boola boolean[], + with self.conn.cursor() as curs: + curs.execute(""" + create table na ( + texta text[], + inta int[], + boola boolean[], - textaa text[][], - intaa int[][], - boolaa boolean[][] - )""") + textaa text[][], + intaa int[][], + boolaa boolean[][] + )""") - curs.execute("insert into na (texta) values (%s)", ([None],)) - curs.execute("insert into na (texta) values (%s)", (['a', None],)) - curs.execute("insert into na (texta) values (%s)", ([None, None],)) - curs.execute("insert into na (inta) values (%s)", ([None],)) - curs.execute("insert into na (inta) values (%s)", ([42, None],)) - curs.execute("insert into na (inta) values (%s)", ([None, None],)) - curs.execute("insert into na (boola) values (%s)", ([None],)) - curs.execute("insert into na (boola) values (%s)", ([True, None],)) - curs.execute("insert into na (boola) values (%s)", ([None, None],)) + curs.execute("insert into na (texta) values (%s)", ([None],)) + curs.execute("insert into na (texta) values (%s)", (['a', None],)) + curs.execute("insert into na (texta) values (%s)", ([None, None],)) + curs.execute("insert into na (inta) values (%s)", ([None],)) + curs.execute("insert into na (inta) values (%s)", ([42, None],)) + curs.execute("insert into na (inta) values (%s)", ([None, None],)) + curs.execute("insert into na (boola) values (%s)", ([None],)) + curs.execute("insert into na (boola) values (%s)", ([True, None],)) + curs.execute("insert into na (boola) values (%s)", ([None, None],)) - curs.execute("insert into na (textaa) values (%s)", ([[None]],)) - curs.execute("insert into na (textaa) values (%s)", ([['a', None]],)) - curs.execute("insert into na (textaa) values (%s)", ([[None, None]],)) + curs.execute("insert into na (textaa) values (%s)", ([[None]],)) + curs.execute("insert into na (textaa) values (%s)", ([['a', None]],)) + curs.execute("insert into na (textaa) values (%s)", ([[None, None]],)) - curs.execute("insert into na (intaa) values (%s)", ([[None]],)) - curs.execute("insert into na (intaa) values (%s)", ([[42, None]],)) - curs.execute("insert into na (intaa) values (%s)", ([[None, None]],)) + curs.execute("insert into na (intaa) values (%s)", ([[None]],)) + curs.execute("insert into na (intaa) values (%s)", ([[42, None]],)) + curs.execute("insert into na (intaa) values (%s)", ([[None, None]],)) - curs.execute("insert into na (boolaa) values (%s)", ([[None]],)) - curs.execute("insert into na (boolaa) values (%s)", ([[True, None]],)) - curs.execute("insert into na (boolaa) values (%s)", ([[None, None]],)) + curs.execute("insert into na (boolaa) values (%s)", ([[None]],)) + curs.execute("insert into na (boolaa) values (%s)", ([[True, None]],)) + curs.execute("insert into na (boolaa) values (%s)", ([[None, None]],)) @testutils.skip_before_postgres(8, 2) def testNestedArrays(self): - curs = self.conn.cursor() - for a in [ - [[1]], - [[None]], - [[None, None, None]], - [[None, None], [1, None]], - [[None, None], [None, None]], - [[[None, None], [None, None]]], - ]: - curs.execute("select %s::int[]", (a,)) - self.assertEqual(curs.fetchone()[0], a) + with self.conn.cursor() as curs: + for a in [ + [[1]], + [[None]], + [[None, None, None]], + [[None, None], [1, None]], + [[None, None], [None, None]], + [[[None, None], [None, None]]], + ]: + curs.execute("select %s::int[]", (a,)) + self.assertEqual(curs.fetchone()[0], a) @testutils.skip_from_python(3) def testTypeRoundtripBuffer(self): diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py index e4828077..081d12f2 100755 --- a/tests/test_types_extras.py +++ b/tests/test_types_extras.py @@ -45,9 +45,9 @@ class TypesExtrasTests(ConnectingTestCase): """Test that all type conversions are working.""" def execute(self, *args): - curs = self.conn.cursor() - curs.execute(*args) - return curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute(*args) + return curs.fetchone()[0] @skip_if_no_uuid def testUUID(self): @@ -231,19 +231,19 @@ class HstoreTestCase(ConnectingTestCase): @skip_if_no_hstore def test_register_conn(self): register_hstore(self.conn) - cur = self.conn.cursor() - cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") - t = cur.fetchone() + with self.conn.cursor() as cur: + cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + t = cur.fetchone() self.assert_(t[0] is None) self.assertEqual(t[1], {}) self.assertEqual(t[2], {'a': 'b'}) @skip_if_no_hstore def test_register_curs(self): - cur = self.conn.cursor() - register_hstore(cur) - cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") - t = cur.fetchone() + with self.conn.cursor() as cur: + register_hstore(cur) + cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + t = cur.fetchone() self.assert_(t[0] is None) self.assertEqual(t[1], {}) self.assertEqual(t[2], {'a': 'b'}) @@ -252,9 +252,9 @@ class HstoreTestCase(ConnectingTestCase): @skip_from_python(3) def test_register_unicode(self): register_hstore(self.conn, unicode=True) - cur = self.conn.cursor() - cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") - t = cur.fetchone() + with self.conn.cursor() as cur: + cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + t = cur.fetchone() self.assert_(t[0] is None) self.assertEqual(t[1], {}) self.assertEqual(t[2], {u'a': u'b'}) @@ -268,9 +268,9 @@ class HstoreTestCase(ConnectingTestCase): register_hstore(self.conn, globally=True) conn2 = self.connect() try: - cur2 = self.conn.cursor() - cur2.execute("select 'a => b'::hstore") - r = cur2.fetchone() + with self.conn.cursor() as cur2: + cur2.execute("select 'a => b'::hstore") + r = cur2.fetchone() self.assert_(isinstance(r[0], dict)) finally: conn2.close() @@ -278,67 +278,66 @@ class HstoreTestCase(ConnectingTestCase): @skip_if_no_hstore def test_roundtrip(self): register_hstore(self.conn) - cur = self.conn.cursor() + with self.conn.cursor() as cur: + def ok(d): + cur.execute("select %s", (d,)) + d1 = cur.fetchone()[0] + self.assertEqual(len(d), len(d1)) + for k in d: + self.assert_(k in d1, k) + self.assertEqual(d[k], d1[k]) - def ok(d): - cur.execute("select %s", (d,)) - d1 = cur.fetchone()[0] - self.assertEqual(len(d), len(d1)) - for k in d: - self.assert_(k in d1, k) - self.assertEqual(d[k], d1[k]) + ok({}) + ok({'a': 'b', 'c': None}) - ok({}) - ok({'a': 'b', 'c': None}) + ab = list(map(chr, range(32, 128))) + ok(dict(zip(ab, ab))) + ok({''.join(ab): ''.join(ab)}) - ab = list(map(chr, range(32, 128))) - ok(dict(zip(ab, ab))) - ok({''.join(ab): ''.join(ab)}) + self.conn.set_client_encoding('latin1') + if PY2: + ab = map(chr, range(32, 127) + range(160, 255)) + else: + ab = bytes( + list(range(32, 127)) + list(range(160, 255))).decode('latin1') - self.conn.set_client_encoding('latin1') - if PY2: - ab = map(chr, range(32, 127) + range(160, 255)) - else: - ab = bytes(list(range(32, 127)) + list(range(160, 255))).decode('latin1') - - ok({''.join(ab): ''.join(ab)}) - ok(dict(zip(ab, ab))) + ok({''.join(ab): ''.join(ab)}) + ok(dict(zip(ab, ab))) @skip_if_no_hstore @skip_from_python(3) def test_roundtrip_unicode(self): register_hstore(self.conn, unicode=True) - cur = self.conn.cursor() + with self.conn.cursor() as cur: + def ok(d): + cur.execute("select %s", (d,)) + d1 = cur.fetchone()[0] + self.assertEqual(len(d), len(d1)) + for k, v in d1.iteritems(): + self.assert_(k in d, k) + self.assertEqual(d[k], v) + self.assert_(isinstance(k, unicode)) + self.assert_(v is None or isinstance(v, unicode)) - def ok(d): - cur.execute("select %s", (d,)) - d1 = cur.fetchone()[0] - self.assertEqual(len(d), len(d1)) - for k, v in d1.iteritems(): - self.assert_(k in d, k) - self.assertEqual(d[k], v) - self.assert_(isinstance(k, unicode)) - self.assert_(v is None or isinstance(v, unicode)) + ok({}) + ok({'a': 'b', 'c': None, 'd': u'\u20ac', u'\u2603': 'e'}) - ok({}) - ok({'a': 'b', 'c': None, 'd': u'\u20ac', u'\u2603': 'e'}) - - ab = map(unichr, range(1, 1024)) - ok({u''.join(ab): u''.join(ab)}) - ok(dict(zip(ab, ab))) + ab = map(unichr, range(1, 1024)) + ok({u''.join(ab): u''.join(ab)}) + ok(dict(zip(ab, ab))) @skip_if_no_hstore @restore_types def test_oid(self): - cur = self.conn.cursor() - cur.execute("select 'hstore'::regtype::oid") - oid = cur.fetchone()[0] + with self.conn.cursor() as cur: + cur.execute("select 'hstore'::regtype::oid") + oid = cur.fetchone()[0] - # Note: None as conn_or_cursor is just for testing: not public - # interface and it may break in future. - register_hstore(None, globally=True, oid=oid) - cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") - t = cur.fetchone() + # Note: None as conn_or_cursor is just for testing: not public + # interface and it may break in future. + register_hstore(None, globally=True, oid=oid) + cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + t = cur.fetchone() self.assert_(t[0] is None) self.assertEqual(t[1], {}) self.assertEqual(t[2], {'a': 'b'}) @@ -363,32 +362,32 @@ class HstoreTestCase(ConnectingTestCase): ds.append({''.join(ab): ''.join(ab)}) ds.append(dict(zip(ab, ab))) - cur = self.conn.cursor() - cur.execute("select %s", (ds,)) - ds1 = cur.fetchone()[0] + with self.conn.cursor() as cur: + cur.execute("select %s", (ds,)) + ds1 = cur.fetchone()[0] self.assertEqual(ds, ds1) @skip_if_no_hstore @skip_before_postgres(8, 3) def test_array_cast(self): register_hstore(self.conn) - cur = self.conn.cursor() - cur.execute("select array['a=>1'::hstore, 'b=>2'::hstore];") - a = cur.fetchone()[0] + with self.conn.cursor() as cur: + cur.execute("select array['a=>1'::hstore, 'b=>2'::hstore];") + a = cur.fetchone()[0] self.assertEqual(a, [{'a': '1'}, {'b': '2'}]) @skip_if_no_hstore @restore_types def test_array_cast_oid(self): - cur = self.conn.cursor() - cur.execute("select 'hstore'::regtype::oid, 'hstore[]'::regtype::oid") - oid, aoid = cur.fetchone() + with self.conn.cursor() as cur: + cur.execute("select 'hstore'::regtype::oid, 'hstore[]'::regtype::oid") + oid, aoid = cur.fetchone() - register_hstore(None, globally=True, oid=oid, array_oid=aoid) - cur.execute(""" - select null::hstore, ''::hstore, - 'a => b'::hstore, '{a=>b}'::hstore[]""") - t = cur.fetchone() + register_hstore(None, globally=True, oid=oid, array_oid=aoid) + cur.execute(""" + select null::hstore, ''::hstore, + 'a => b'::hstore, '{a=>b}'::hstore[]""") + t = cur.fetchone() self.assert_(t[0] is None) self.assertEqual(t[1], {}) self.assertEqual(t[2], {'a': 'b'}) @@ -399,18 +398,18 @@ class HstoreTestCase(ConnectingTestCase): conn = self.connect(connection_factory=RealDictConnection) try: register_hstore(conn) - curs = conn.cursor() - curs.execute("select ''::hstore as x") - self.assertEqual(curs.fetchone()['x'], {}) + with conn.cursor() as curs: + curs.execute("select ''::hstore as x") + self.assertEqual(curs.fetchone()['x'], {}) finally: conn.close() conn = self.connect(connection_factory=RealDictConnection) try: - curs = conn.cursor() - register_hstore(curs) - curs.execute("select ''::hstore as x") - self.assertEqual(curs.fetchone()['x'], {}) + with conn.cursor() as curs: + register_hstore(curs) + curs.execute("select ''::hstore as x") + self.assertEqual(curs.fetchone()['x'], {}) finally: conn.close() @@ -431,12 +430,12 @@ def skip_if_no_composite(f): class AdaptTypeTestCase(ConnectingTestCase): @skip_if_no_composite def test_none_in_record(self): - curs = self.conn.cursor() - s = curs.mogrify("SELECT %s;", [(42, None)]) - self.assertEqual(b"SELECT (42, NULL);", s) - curs.execute("SELECT %s;", [(42, None)]) - d = curs.fetchone()[0] - self.assertEqual("(42,)", d) + with self.conn.cursor() as curs: + s = curs.mogrify("SELECT %s;", [(42, None)]) + self.assertEqual(b"SELECT (42, NULL);", s) + curs.execute("SELECT %s;", [(42, None)]) + d = curs.fetchone()[0] + self.assertEqual("(42,)", d) def test_none_fast_path(self): # the None adapter is not actually invoked in regular adaptation @@ -448,18 +447,17 @@ class AdaptTypeTestCase(ConnectingTestCase): def getquoted(self): return "NOPE!" - curs = self.conn.cursor() + with self.conn.cursor() as curs: + orig_adapter = ext.adapters[type(None), ext.ISQLQuote] + try: + ext.register_adapter(type(None), WonkyAdapter) + self.assertEqual(ext.adapt(None).getquoted(), "NOPE!") - orig_adapter = ext.adapters[type(None), ext.ISQLQuote] - try: - ext.register_adapter(type(None), WonkyAdapter) - self.assertEqual(ext.adapt(None).getquoted(), "NOPE!") + s = curs.mogrify("SELECT %s;", (None,)) + self.assertEqual(b"SELECT NULL;", s) - s = curs.mogrify("SELECT %s;", (None,)) - self.assertEqual(b"SELECT NULL;", s) - - finally: - ext.register_adapter(type(None), orig_adapter) + finally: + ext.register_adapter(type(None), orig_adapter) def test_tokenization(self): def ok(s, v): @@ -502,10 +500,10 @@ class AdaptTypeTestCase(ConnectingTestCase): self.assertEqual(t.attnames, ['anint', 'astring', 'adate']) self.assertEqual(t.atttypes, [23, 25, 1082]) - curs = self.conn.cursor() - r = (10, 'hello', date(2011, 1, 2)) - curs.execute("select %s::type_isd;", (r,)) - v = curs.fetchone()[0] + with self.conn.cursor() as curs: + r = (10, 'hello', date(2011, 1, 2)) + curs.execute("select %s::type_isd;", (r,)) + v = curs.fetchone()[0] self.assert_(isinstance(v, t.type)) self.assertEqual(v[0], 10) self.assertEqual(v[1], "hello") @@ -519,21 +517,21 @@ class AdaptTypeTestCase(ConnectingTestCase): def test_empty_string(self): # issue #141 self._create_type("type_ss", [('s1', 'text'), ('s2', 'text')]) - curs = self.conn.cursor() - psycopg2.extras.register_composite("type_ss", curs) + with self.conn.cursor() as curs: + psycopg2.extras.register_composite("type_ss", curs) - def ok(t): - curs.execute("select %s::type_ss", (t,)) - rv = curs.fetchone()[0] - self.assertEqual(t, rv) + def ok(t): + curs.execute("select %s::type_ss", (t,)) + rv = curs.fetchone()[0] + self.assertEqual(t, rv) - ok(('a', 'b')) - ok(('a', '')) - ok(('', 'b')) - ok(('a', None)) - ok((None, 'b')) - ok(('', '')) - ok((None, None)) + ok(('a', 'b')) + ok(('a', '')) + ok(('', 'b')) + ok(('a', None)) + ok((None, 'b')) + ok(('', '')) + ok((None, None)) @skip_if_no_composite def test_cast_nested(self): @@ -548,10 +546,10 @@ class AdaptTypeTestCase(ConnectingTestCase): psycopg2.extras.register_composite("type_r_dt", self.conn) psycopg2.extras.register_composite("type_r_ft", self.conn) - curs = self.conn.cursor() - r = (0.25, (date(2011, 1, 2), (42, "hello"))) - curs.execute("select %s::type_r_ft;", (r,)) - v = curs.fetchone()[0] + with self.conn.cursor() as curs: + r = (0.25, (date(2011, 1, 2), (42, "hello"))) + curs.execute("select %s::type_r_ft;", (r,)) + v = curs.fetchone()[0] self.assertEqual(r, v) self.assertEqual(v.anotherpair.apair.astring, "hello") @@ -560,13 +558,12 @@ class AdaptTypeTestCase(ConnectingTestCase): def test_register_on_cursor(self): self._create_type("type_ii", [("a", "integer"), ("b", "integer")]) - curs1 = self.conn.cursor() - curs2 = self.conn.cursor() - psycopg2.extras.register_composite("type_ii", curs1) - curs1.execute("select (1,2)::type_ii") - self.assertEqual(curs1.fetchone()[0], (1, 2)) - curs2.execute("select (1,2)::type_ii") - self.assertEqual(curs2.fetchone()[0], "(1,2)") + with self.conn.cursor() as curs1, self.conn.cursor() as curs2: + psycopg2.extras.register_composite("type_ii", curs1) + curs1.execute("select (1,2)::type_ii") + self.assertEqual(curs1.fetchone()[0], (1, 2)) + curs2.execute("select (1,2)::type_ii") + self.assertEqual(curs2.fetchone()[0], "(1,2)") @skip_if_no_composite def test_register_on_connection(self): @@ -576,12 +573,11 @@ class AdaptTypeTestCase(ConnectingTestCase): conn2 = self.connect() try: psycopg2.extras.register_composite("type_ii", conn1) - curs1 = conn1.cursor() - curs2 = conn2.cursor() - curs1.execute("select (1,2)::type_ii") - self.assertEqual(curs1.fetchone()[0], (1, 2)) - curs2.execute("select (1,2)::type_ii") - self.assertEqual(curs2.fetchone()[0], "(1,2)") + with conn1.cursor() as curs1, conn2.cursor() as curs2: + curs1.execute("select (1,2)::type_ii") + self.assertEqual(curs1.fetchone()[0], (1, 2)) + curs2.execute("select (1,2)::type_ii") + self.assertEqual(curs2.fetchone()[0], "(1,2)") finally: conn1.close() conn2.close() @@ -595,12 +591,11 @@ class AdaptTypeTestCase(ConnectingTestCase): conn2 = self.connect() try: psycopg2.extras.register_composite("type_ii", conn1, globally=True) - curs1 = conn1.cursor() - curs2 = conn2.cursor() - curs1.execute("select (1,2)::type_ii") - self.assertEqual(curs1.fetchone()[0], (1, 2)) - curs2.execute("select (1,2)::type_ii") - self.assertEqual(curs2.fetchone()[0], (1, 2)) + with conn1.cursor() as curs1, conn2.cursor() as curs2: + curs1.execute("select (1,2)::type_ii") + self.assertEqual(curs1.fetchone()[0], (1, 2)) + curs2.execute("select (1,2)::type_ii") + self.assertEqual(curs2.fetchone()[0], (1, 2)) finally: conn1.close() @@ -608,22 +603,22 @@ class AdaptTypeTestCase(ConnectingTestCase): @skip_if_no_composite def test_composite_namespace(self): - curs = self.conn.cursor() - curs.execute(""" - select nspname from pg_namespace - where nspname = 'typens'; - """) - if not curs.fetchone(): - curs.execute("create schema typens;") - self.conn.commit() + with self.conn.cursor() as curs: + curs.execute(""" + select nspname from pg_namespace + where nspname = 'typens'; + """) + if not curs.fetchone(): + curs.execute("create schema typens;") + self.conn.commit() - self._create_type("typens.typens_ii", - [("a", "integer"), ("b", "integer")]) - t = psycopg2.extras.register_composite( - "typens.typens_ii", self.conn) - self.assertEqual(t.schema, 'typens') - curs.execute("select (4,8)::typens.typens_ii") - self.assertEqual(curs.fetchone()[0], (4, 8)) + self._create_type("typens.typens_ii", + [("a", "integer"), ("b", "integer")]) + t = psycopg2.extras.register_composite( + "typens.typens_ii", self.conn) + self.assertEqual(t.schema, 'typens') + curs.execute("select (4,8)::typens.typens_ii") + self.assertEqual(curs.fetchone()[0], (4, 8)) @skip_if_no_composite @skip_before_postgres(8, 4) @@ -633,11 +628,11 @@ class AdaptTypeTestCase(ConnectingTestCase): t = psycopg2.extras.register_composite("type_isd", self.conn) - curs = self.conn.cursor() - r1 = (10, 'hello', date(2011, 1, 2)) - r2 = (20, 'world', date(2011, 1, 3)) - curs.execute("select %s::type_isd[];", ([r1, r2],)) - v = curs.fetchone()[0] + with self.conn.cursor() as curs: + r1 = (10, 'hello', date(2011, 1, 2)) + r2 = (20, 'world', date(2011, 1, 3)) + curs.execute("select %s::type_isd[];", ([r1, r2],)) + v = curs.fetchone()[0] self.assertEqual(len(v), 2) self.assert_(isinstance(v[0], t.type)) self.assertEqual(v[0][0], 10) @@ -652,56 +647,56 @@ class AdaptTypeTestCase(ConnectingTestCase): def test_wrong_schema(self): oid = self._create_type("type_ii", [("a", "integer"), ("b", "integer")]) c = CompositeCaster('type_ii', oid, [('a', 23), ('b', 23), ('c', 23)]) - curs = self.conn.cursor() - psycopg2.extensions.register_type(c.typecaster, curs) - curs.execute("select (1,2)::type_ii") - self.assertRaises(psycopg2.DataError, curs.fetchone) + with self.conn.cursor() as curs: + psycopg2.extensions.register_type(c.typecaster, curs) + curs.execute("select (1,2)::type_ii") + self.assertRaises(psycopg2.DataError, curs.fetchone) @slow @skip_if_no_composite @skip_before_postgres(8, 4) def test_from_tables(self): - curs = self.conn.cursor() - curs.execute("""create table ctest1 ( - id integer primary key, - temp int, - label varchar - );""") + with self.conn.cursor() as curs: + curs.execute("""create table ctest1 ( + id integer primary key, + temp int, + label varchar + );""") - curs.execute("""alter table ctest1 drop temp;""") + curs.execute("""alter table ctest1 drop temp;""") - curs.execute("""create table ctest2 ( - id serial primary key, - label varchar, - test_id integer references ctest1(id) - );""") + curs.execute("""create table ctest2 ( + id serial primary key, + label varchar, + test_id integer references ctest1(id) + );""") - curs.execute("""insert into ctest1 (id, label) values - (1, 'test1'), - (2, 'test2');""") - curs.execute("""insert into ctest2 (label, test_id) values - ('testa', 1), - ('testb', 1), - ('testc', 2), - ('testd', 2);""") + curs.execute("""insert into ctest1 (id, label) values + (1, 'test1'), + (2, 'test2');""") + curs.execute("""insert into ctest2 (label, test_id) values + ('testa', 1), + ('testb', 1), + ('testc', 2), + ('testd', 2);""") - psycopg2.extras.register_composite("ctest1", curs) - psycopg2.extras.register_composite("ctest2", curs) + psycopg2.extras.register_composite("ctest1", curs) + psycopg2.extras.register_composite("ctest2", curs) - curs.execute(""" - select ctest1, array_agg(ctest2) as test2s - from ( - select ctest1, ctest2 - from ctest1 inner join ctest2 on ctest1.id = ctest2.test_id - order by ctest1.id, ctest2.label - ) x group by ctest1;""") + curs.execute(""" + select ctest1, array_agg(ctest2) as test2s + from ( + select ctest1, ctest2 + from ctest1 inner join ctest2 on ctest1.id = ctest2.test_id + order by ctest1.id, ctest2.label + ) x group by ctest1;""") - r = curs.fetchone() - self.assertEqual(r[0], (1, 'test1')) - self.assertEqual(r[1], [(1, 'testa', 1), (2, 'testb', 1)]) - r = curs.fetchone() - self.assertEqual(r[0], (2, 'test2')) - self.assertEqual(r[1], [(3, 'testc', 2), (4, 'testd', 2)]) + r = curs.fetchone() + self.assertEqual(r[0], (1, 'test1')) + self.assertEqual(r[1], [(1, 'testa', 1), (2, 'testb', 1)]) + r = curs.fetchone() + self.assertEqual(r[0], (2, 'test2')) + self.assertEqual(r[1], [(3, 'testc', 2), (4, 'testd', 2)]) @skip_if_no_composite def test_non_dbapi_connection(self): @@ -710,18 +705,18 @@ class AdaptTypeTestCase(ConnectingTestCase): conn = self.connect(connection_factory=RealDictConnection) try: register_composite('type_ii', conn) - curs = conn.cursor() - curs.execute("select '(1,2)'::type_ii as x") - self.assertEqual(curs.fetchone()['x'], (1, 2)) + with conn.cursor() as curs: + curs.execute("select '(1,2)'::type_ii as x") + self.assertEqual(curs.fetchone()['x'], (1, 2)) finally: conn.close() conn = self.connect(connection_factory=RealDictConnection) try: - curs = conn.cursor() - register_composite('type_ii', conn) - curs.execute("select '(1,2)'::type_ii as x") - self.assertEqual(curs.fetchone()['x'], (1, 2)) + with conn.cursor() as curs: + register_composite('type_ii', conn) + curs.execute("select '(1,2)'::type_ii as x") + self.assertEqual(curs.fetchone()['x'], (1, 2)) finally: conn.close() @@ -739,35 +734,35 @@ class AdaptTypeTestCase(ConnectingTestCase): self.assertEqual(t.name, 'type_isd') self.assertEqual(t.oid, oid) - curs = self.conn.cursor() - r = (10, 'hello', date(2011, 1, 2)) - curs.execute("select %s::type_isd;", (r,)) - v = curs.fetchone()[0] + with self.conn.cursor() as curs: + r = (10, 'hello', date(2011, 1, 2)) + curs.execute("select %s::type_isd;", (r,)) + v = curs.fetchone()[0] self.assert_(isinstance(v, dict)) self.assertEqual(v['anint'], 10) self.assertEqual(v['astring'], "hello") self.assertEqual(v['adate'], date(2011, 1, 2)) def _create_type(self, name, fields): - curs = self.conn.cursor() - try: - curs.execute("drop type %s cascade;" % name) - except psycopg2.ProgrammingError: - self.conn.rollback() + with self.conn.cursor() as curs: + try: + curs.execute("drop type %s cascade;" % name) + except psycopg2.ProgrammingError: + self.conn.rollback() - curs.execute("create type %s as (%s);" % (name, - ", ".join(["%s %s" % p for p in fields]))) - if '.' in name: - schema, name = name.split('.') - else: - schema = 'public' + curs.execute("create type %s as (%s);" % (name, + ", ".join(["%s %s" % p for p in fields]))) + if '.' in name: + schema, name = name.split('.') + else: + schema = 'public' - curs.execute("""\ - SELECT t.oid - FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid - WHERE typname = %s and nspname = %s; - """, (name, schema)) - oid = curs.fetchone()[0] + curs.execute("""\ + SELECT t.oid + FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid + WHERE typname = %s and nspname = %s; + """, (name, schema)) + oid = curs.fetchone()[0] self.conn.commit() return oid @@ -776,10 +771,10 @@ def skip_if_no_json_type(f): """Skip a test if PostgreSQL json type is not available""" @wraps(f) def skip_if_no_json_type_(self): - curs = self.conn.cursor() - curs.execute("select oid from pg_type where typname = 'json'") - if not curs.fetchone(): - return self.skipTest("json not available in test database") + with self.conn.cursor() as curs: + curs.execute("select oid from pg_type where typname = 'json'") + if not curs.fetchone(): + return self.skipTest("json not available in test database") return f(self) @@ -791,10 +786,10 @@ class JsonTestCase(ConnectingTestCase): objs = [None, "te'xt", 123, 123.45, u'\xe0\u20ac', ['a', 100], {'a': 100}] - curs = self.conn.cursor() - for obj in enumerate(objs): - self.assertQuotedEqual(curs.mogrify("%s", (Json(obj),)), - psycopg2.extensions.QuotedString(json.dumps(obj)).getquoted()) + with self.conn.cursor() as curs: + for obj in enumerate(objs): + self.assertQuotedEqual(curs.mogrify("%s", (Json(obj),)), + psycopg2.extensions.QuotedString(json.dumps(obj)).getquoted()) def test_adapt_dumps(self): class DecimalEncoder(json.JSONEncoder): @@ -803,13 +798,13 @@ class JsonTestCase(ConnectingTestCase): return float(obj) return json.JSONEncoder.default(self, obj) - curs = self.conn.cursor() - obj = Decimal('123.45') + with self.conn.cursor() as curs: + obj = Decimal('123.45') - def dumps(obj): - return json.dumps(obj, cls=DecimalEncoder) - self.assertQuotedEqual(curs.mogrify("%s", (Json(obj, dumps=dumps),)), - b"'123.45'") + def dumps(obj): + return json.dumps(obj, cls=DecimalEncoder) + self.assertQuotedEqual(curs.mogrify("%s", (Json(obj, dumps=dumps),)), + b"'123.45'") def test_adapt_subclass(self): class DecimalEncoder(json.JSONEncoder): @@ -822,59 +817,58 @@ class JsonTestCase(ConnectingTestCase): def dumps(self, obj): return json.dumps(obj, cls=DecimalEncoder) - curs = self.conn.cursor() - obj = Decimal('123.45') - self.assertQuotedEqual(curs.mogrify("%s", (MyJson(obj),)), b"'123.45'") + with self.conn.cursor() as curs: + obj = Decimal('123.45') + self.assertQuotedEqual(curs.mogrify("%s", (MyJson(obj),)), b"'123.45'") @restore_types def test_register_on_dict(self): psycopg2.extensions.register_adapter(dict, Json) - curs = self.conn.cursor() - obj = {'a': 123} - self.assertQuotedEqual( - curs.mogrify("%s", (obj,)), b"""'{"a": 123}'""") + with self.conn.cursor() as curs: + obj = {'a': 123} + self.assertQuotedEqual( + curs.mogrify("%s", (obj,)), b"""'{"a": 123}'""") def test_type_not_available(self): - curs = self.conn.cursor() - curs.execute("select oid from pg_type where typname = 'json'") - if curs.fetchone(): - return self.skipTest("json available in test database") + with self.conn.cursor() as curs: + curs.execute("select oid from pg_type where typname = 'json'") + if curs.fetchone(): + return self.skipTest("json available in test database") self.assertRaises(psycopg2.ProgrammingError, psycopg2.extras.register_json, self.conn) @skip_before_postgres(9, 2) def test_default_cast(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) - - curs.execute("""select array['{"a": 100.0, "b": null}']::json[]""") - self.assertEqual(curs.fetchone()[0], [{'a': 100.0, 'b': None}]) + curs.execute("""select array['{"a": 100.0, "b": null}']::json[]""") + self.assertEqual(curs.fetchone()[0], [{'a': 100.0, 'b': None}]) @skip_if_no_json_type def test_register_on_connection(self): psycopg2.extras.register_json(self.conn) - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) @skip_if_no_json_type def test_register_on_cursor(self): - curs = self.conn.cursor() - psycopg2.extras.register_json(curs) - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) + with self.conn.cursor() as curs: + psycopg2.extras.register_json(curs) + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) @skip_if_no_json_type @restore_types def test_register_globally(self): new, newa = psycopg2.extras.register_json(self.conn, globally=True) - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) @skip_if_no_json_type def test_loads(self): @@ -883,9 +877,9 @@ class JsonTestCase(ConnectingTestCase): def loads(s): return json.loads(s, parse_float=Decimal) psycopg2.extras.register_json(self.conn, loads=loads) - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - data = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + data = curs.fetchone()[0] self.assert_(isinstance(data['a'], Decimal)) self.assertEqual(data['a'], Decimal('100.0')) @@ -899,49 +893,48 @@ class JsonTestCase(ConnectingTestCase): new, newa = psycopg2.extras.register_json( loads=loads, oid=oid, array_oid=array_oid) - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - data = curs.fetchone()[0] + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + data = curs.fetchone()[0] self.assert_(isinstance(data['a'], Decimal)) self.assertEqual(data['a'], Decimal('100.0')) @skip_before_postgres(9, 2) def test_register_default(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + def loads(s): + return psycopg2.extras.json.loads(s, parse_float=Decimal) + psycopg2.extras.register_default_json(curs, loads=loads) - def loads(s): - return psycopg2.extras.json.loads(s, parse_float=Decimal) - psycopg2.extras.register_default_json(curs, loads=loads) + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + data = curs.fetchone()[0] + self.assert_(isinstance(data['a'], Decimal)) + self.assertEqual(data['a'], Decimal('100.0')) - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - data = curs.fetchone()[0] - self.assert_(isinstance(data['a'], Decimal)) - self.assertEqual(data['a'], Decimal('100.0')) - - curs.execute("""select array['{"a": 100.0, "b": null}']::json[]""") - data = curs.fetchone()[0] - self.assert_(isinstance(data[0]['a'], Decimal)) - self.assertEqual(data[0]['a'], Decimal('100.0')) + curs.execute("""select array['{"a": 100.0, "b": null}']::json[]""") + data = curs.fetchone()[0] + self.assert_(isinstance(data[0]['a'], Decimal)) + self.assertEqual(data[0]['a'], Decimal('100.0')) @skip_if_no_json_type def test_null(self): psycopg2.extras.register_json(self.conn) - curs = self.conn.cursor() - curs.execute("""select NULL::json""") - self.assertEqual(curs.fetchone()[0], None) - curs.execute("""select NULL::json[]""") - self.assertEqual(curs.fetchone()[0], None) + with self.conn.cursor() as curs: + curs.execute("""select NULL::json""") + self.assertEqual(curs.fetchone()[0], None) + curs.execute("""select NULL::json[]""") + self.assertEqual(curs.fetchone()[0], None) def test_no_array_oid(self): - curs = self.conn.cursor() - t1, t2 = psycopg2.extras.register_json(curs, oid=25) - self.assertEqual(t1.values[0], 25) - self.assertEqual(t2, None) + with self.conn.cursor() as curs: + t1, t2 = psycopg2.extras.register_json(curs, oid=25) + self.assertEqual(t1.values[0], 25) + self.assertEqual(t2, None) - curs.execute("""select '{"a": 100.0, "b": null}'::text""") - data = curs.fetchone()[0] - self.assertEqual(data['a'], 100) - self.assertEqual(data['b'], None) + curs.execute("""select '{"a": 100.0, "b": null}'::text""") + data = curs.fetchone()[0] + self.assertEqual(data['a'], 100) + self.assertEqual(data['b'], None) def test_str(self): snowman = u"\u2603" @@ -956,20 +949,20 @@ class JsonTestCase(ConnectingTestCase): @skip_before_postgres(8, 2) def test_scs(self): cnn_on = self.connect(options="-c standard_conforming_strings=on") - cur_on = cnn_on.cursor() - self.assertEqual( - cur_on.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), - b'\'{"a": "\\""}\'') + with cnn_on.cursor() as cur_on: + self.assertEqual( + cur_on.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), + b'\'{"a": "\\""}\'') cnn_off = self.connect(options="-c standard_conforming_strings=off") - cur_off = cnn_off.cursor() - self.assertEqual( - cur_off.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), - b'E\'{"a": "\\\\""}\'') + with cnn_off.cursor() as cur_off: + self.assertEqual( + cur_off.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), + b'E\'{"a": "\\\\""}\'') - self.assertEqual( - cur_on.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), - b'\'{"a": "\\""}\'') + self.assertEqual( + cur_on.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), + b'\'{"a": "\\""}\'') def skip_if_no_jsonb_type(f): @@ -985,33 +978,32 @@ class JsonbTestCase(ConnectingTestCase): return rv def test_default_cast(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) - - curs.execute("""select array['{"a": 100.0, "b": null}']::jsonb[]""") - self.assertEqual(curs.fetchone()[0], [{'a': 100.0, 'b': None}]) + curs.execute("""select array['{"a": 100.0, "b": null}']::jsonb[]""") + self.assertEqual(curs.fetchone()[0], [{'a': 100.0, 'b': None}]) def test_register_on_connection(self): psycopg2.extras.register_json(self.conn, loads=self.myloads, name='jsonb') - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) def test_register_on_cursor(self): - curs = self.conn.cursor() - psycopg2.extras.register_json(curs, loads=self.myloads, name='jsonb') - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) + with self.conn.cursor() as curs: + psycopg2.extras.register_json(curs, loads=self.myloads, name='jsonb') + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) @restore_types def test_register_globally(self): new, newa = psycopg2.extras.register_json(self.conn, loads=self.myloads, globally=True, name='jsonb') - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) def test_loads(self): json = psycopg2.extras.json @@ -1020,41 +1012,40 @@ class JsonbTestCase(ConnectingTestCase): return json.loads(s, parse_float=Decimal) psycopg2.extras.register_json(self.conn, loads=loads, name='jsonb') - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - data = curs.fetchone()[0] - self.assert_(isinstance(data['a'], Decimal)) - self.assertEqual(data['a'], Decimal('100.0')) - # sure we are not manling json too? - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - data = curs.fetchone()[0] - self.assert_(isinstance(data['a'], float)) - self.assertEqual(data['a'], 100.0) + with self.conn.cursor() as curs: + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + data = curs.fetchone()[0] + self.assert_(isinstance(data['a'], Decimal)) + self.assertEqual(data['a'], Decimal('100.0')) + # sure we are not manling json too? + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + data = curs.fetchone()[0] + self.assert_(isinstance(data['a'], float)) + self.assertEqual(data['a'], 100.0) def test_register_default(self): - curs = self.conn.cursor() + with self.conn.cursor() as curs: + def loads(s): + return psycopg2.extras.json.loads(s, parse_float=Decimal) - def loads(s): - return psycopg2.extras.json.loads(s, parse_float=Decimal) + psycopg2.extras.register_default_jsonb(curs, loads=loads) - psycopg2.extras.register_default_jsonb(curs, loads=loads) + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + data = curs.fetchone()[0] + self.assert_(isinstance(data['a'], Decimal)) + self.assertEqual(data['a'], Decimal('100.0')) - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - data = curs.fetchone()[0] - self.assert_(isinstance(data['a'], Decimal)) - self.assertEqual(data['a'], Decimal('100.0')) - - curs.execute("""select array['{"a": 100.0, "b": null}']::jsonb[]""") - data = curs.fetchone()[0] - self.assert_(isinstance(data[0]['a'], Decimal)) - self.assertEqual(data[0]['a'], Decimal('100.0')) + curs.execute("""select array['{"a": 100.0, "b": null}']::jsonb[]""") + data = curs.fetchone()[0] + self.assert_(isinstance(data[0]['a'], Decimal)) + self.assertEqual(data[0]['a'], Decimal('100.0')) def test_null(self): - curs = self.conn.cursor() - curs.execute("""select NULL::jsonb""") - self.assertEqual(curs.fetchone()[0], None) - curs.execute("""select NULL::jsonb[]""") - self.assertEqual(curs.fetchone()[0], None) + with self.conn.cursor() as curs: + curs.execute("""select NULL::jsonb""") + self.assertEqual(curs.fetchone()[0], None) + curs.execute("""select NULL::jsonb[]""") + self.assertEqual(curs.fetchone()[0], None) class RangeTestCase(unittest.TestCase): @@ -1332,59 +1323,59 @@ class RangeCasterTestCase(ConnectingTestCase): 'daterange', 'tsrange', 'tstzrange') def test_cast_null(self): - cur = self.conn.cursor() - for type in self.builtin_ranges: - cur.execute("select NULL::%s" % type) - r = cur.fetchone()[0] - self.assertEqual(r, None) + with self.conn.cursor() as cur: + for type in self.builtin_ranges: + cur.execute("select NULL::%s" % type) + r = cur.fetchone()[0] + self.assertEqual(r, None) def test_cast_empty(self): - cur = self.conn.cursor() - for type in self.builtin_ranges: - cur.execute("select 'empty'::%s" % type) - r = cur.fetchone()[0] - self.assert_(isinstance(r, Range), type) - self.assert_(r.isempty) + with self.conn.cursor() as cur: + for type in self.builtin_ranges: + cur.execute("select 'empty'::%s" % type) + r = cur.fetchone()[0] + self.assert_(isinstance(r, Range), type) + self.assert_(r.isempty) def test_cast_inf(self): - cur = self.conn.cursor() - for type in self.builtin_ranges: - cur.execute("select '(,)'::%s" % type) - r = cur.fetchone()[0] - self.assert_(isinstance(r, Range), type) - self.assert_(not r.isempty) - self.assert_(r.lower_inf) - self.assert_(r.upper_inf) + with self.conn.cursor() as cur: + for type in self.builtin_ranges: + cur.execute("select '(,)'::%s" % type) + r = cur.fetchone()[0] + self.assert_(isinstance(r, Range), type) + self.assert_(not r.isempty) + self.assert_(r.lower_inf) + self.assert_(r.upper_inf) def test_cast_numbers(self): - cur = self.conn.cursor() - for type in ('int4range', 'int8range'): - cur.execute("select '(10,20)'::%s" % type) + with self.conn.cursor() as cur: + for type in ('int4range', 'int8range'): + cur.execute("select '(10,20)'::%s" % type) + r = cur.fetchone()[0] + self.assert_(isinstance(r, NumericRange)) + self.assert_(not r.isempty) + self.assertEqual(r.lower, 11) + self.assertEqual(r.upper, 20) + self.assert_(not r.lower_inf) + self.assert_(not r.upper_inf) + self.assert_(r.lower_inc) + self.assert_(not r.upper_inc) + + cur.execute("select '(10.2,20.6)'::numrange") r = cur.fetchone()[0] self.assert_(isinstance(r, NumericRange)) self.assert_(not r.isempty) - self.assertEqual(r.lower, 11) - self.assertEqual(r.upper, 20) + self.assertEqual(r.lower, Decimal('10.2')) + self.assertEqual(r.upper, Decimal('20.6')) self.assert_(not r.lower_inf) self.assert_(not r.upper_inf) - self.assert_(r.lower_inc) + self.assert_(not r.lower_inc) self.assert_(not r.upper_inc) - cur.execute("select '(10.2,20.6)'::numrange") - r = cur.fetchone()[0] - self.assert_(isinstance(r, NumericRange)) - self.assert_(not r.isempty) - self.assertEqual(r.lower, Decimal('10.2')) - self.assertEqual(r.upper, Decimal('20.6')) - self.assert_(not r.lower_inf) - self.assert_(not r.upper_inf) - self.assert_(not r.lower_inc) - self.assert_(not r.upper_inc) - def test_cast_date(self): - cur = self.conn.cursor() - cur.execute("select '(2000-01-01,2012-12-31)'::daterange") - r = cur.fetchone()[0] + with self.conn.cursor() as cur: + cur.execute("select '(2000-01-01,2012-12-31)'::daterange") + r = cur.fetchone()[0] self.assert_(isinstance(r, DateRange)) self.assert_(not r.isempty) self.assertEqual(r.lower, date(2000, 1, 2)) @@ -1395,11 +1386,11 @@ class RangeCasterTestCase(ConnectingTestCase): self.assert_(not r.upper_inc) def test_cast_timestamp(self): - cur = self.conn.cursor() - ts1 = datetime(2000, 1, 1) - ts2 = datetime(2000, 12, 31, 23, 59, 59, 999) - cur.execute("select tsrange(%s, %s, '()')", (ts1, ts2)) - r = cur.fetchone()[0] + with self.conn.cursor() as cur: + ts1 = datetime(2000, 1, 1) + ts2 = datetime(2000, 12, 31, 23, 59, 59, 999) + cur.execute("select tsrange(%s, %s, '()')", (ts1, ts2)) + r = cur.fetchone()[0] self.assert_(isinstance(r, DateTimeRange)) self.assert_(not r.isempty) self.assertEqual(r.lower, ts1) @@ -1410,12 +1401,12 @@ class RangeCasterTestCase(ConnectingTestCase): self.assert_(not r.upper_inc) def test_cast_timestamptz(self): - cur = self.conn.cursor() - ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600)) - ts2 = datetime(2000, 12, 31, 23, 59, 59, 999, - tzinfo=FixedOffsetTimezone(600)) - cur.execute("select tstzrange(%s, %s, '[]')", (ts1, ts2)) - r = cur.fetchone()[0] + with self.conn.cursor() as cur: + ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600)) + ts2 = datetime(2000, 12, 31, 23, 59, 59, 999, + tzinfo=FixedOffsetTimezone(600)) + cur.execute("select tstzrange(%s, %s, '[]')", (ts1, ts2)) + r = cur.fetchone()[0] self.assert_(isinstance(r, DateTimeTZRange)) self.assert_(not r.isempty) self.assertEqual(r.lower, ts1) @@ -1426,202 +1417,199 @@ class RangeCasterTestCase(ConnectingTestCase): self.assert_(r.upper_inc) def test_adapt_number_range(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + r = NumericRange(empty=True) + cur.execute("select %s::int4range", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, NumericRange)) + self.assert_(r1.isempty) - r = NumericRange(empty=True) - cur.execute("select %s::int4range", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, NumericRange)) - self.assert_(r1.isempty) + r = NumericRange(10, 20) + cur.execute("select %s::int8range", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, NumericRange)) + self.assertEqual(r1.lower, 10) + self.assertEqual(r1.upper, 20) + self.assert_(r1.lower_inc) + self.assert_(not r1.upper_inc) - r = NumericRange(10, 20) - cur.execute("select %s::int8range", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, NumericRange)) - self.assertEqual(r1.lower, 10) - self.assertEqual(r1.upper, 20) - self.assert_(r1.lower_inc) - self.assert_(not r1.upper_inc) - - r = NumericRange(Decimal('10.2'), Decimal('20.5'), '(]') - cur.execute("select %s::numrange", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, NumericRange)) - self.assertEqual(r1.lower, Decimal('10.2')) - self.assertEqual(r1.upper, Decimal('20.5')) - self.assert_(not r1.lower_inc) - self.assert_(r1.upper_inc) + r = NumericRange(Decimal('10.2'), Decimal('20.5'), '(]') + cur.execute("select %s::numrange", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, NumericRange)) + self.assertEqual(r1.lower, Decimal('10.2')) + self.assertEqual(r1.upper, Decimal('20.5')) + self.assert_(not r1.lower_inc) + self.assert_(r1.upper_inc) def test_adapt_numeric_range(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + r = NumericRange(empty=True) + cur.execute("select %s::int4range", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, NumericRange), r1) + self.assert_(r1.isempty) - r = NumericRange(empty=True) - cur.execute("select %s::int4range", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, NumericRange), r1) - self.assert_(r1.isempty) + r = NumericRange(10, 20) + cur.execute("select %s::int8range", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, NumericRange)) + self.assertEqual(r1.lower, 10) + self.assertEqual(r1.upper, 20) + self.assert_(r1.lower_inc) + self.assert_(not r1.upper_inc) - r = NumericRange(10, 20) - cur.execute("select %s::int8range", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, NumericRange)) - self.assertEqual(r1.lower, 10) - self.assertEqual(r1.upper, 20) - self.assert_(r1.lower_inc) - self.assert_(not r1.upper_inc) - - r = NumericRange(Decimal('10.2'), Decimal('20.5'), '(]') - cur.execute("select %s::numrange", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, NumericRange)) - self.assertEqual(r1.lower, Decimal('10.2')) - self.assertEqual(r1.upper, Decimal('20.5')) - self.assert_(not r1.lower_inc) - self.assert_(r1.upper_inc) + r = NumericRange(Decimal('10.2'), Decimal('20.5'), '(]') + cur.execute("select %s::numrange", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, NumericRange)) + self.assertEqual(r1.lower, Decimal('10.2')) + self.assertEqual(r1.upper, Decimal('20.5')) + self.assert_(not r1.lower_inc) + self.assert_(r1.upper_inc) def test_adapt_date_range(self): - cur = self.conn.cursor() + with self.conn.cursor() as cur: + d1 = date(2012, 1, 1) + d2 = date(2012, 12, 31) + r = DateRange(d1, d2) + cur.execute("select %s", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, DateRange)) + self.assertEqual(r1.lower, d1) + self.assertEqual(r1.upper, d2) + self.assert_(r1.lower_inc) + self.assert_(not r1.upper_inc) - d1 = date(2012, 1, 1) - d2 = date(2012, 12, 31) - r = DateRange(d1, d2) - cur.execute("select %s", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, DateRange)) - self.assertEqual(r1.lower, d1) - self.assertEqual(r1.upper, d2) - self.assert_(r1.lower_inc) - self.assert_(not r1.upper_inc) + r = DateTimeRange(empty=True) + cur.execute("select %s", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, DateTimeRange)) + self.assert_(r1.isempty) - r = DateTimeRange(empty=True) - cur.execute("select %s", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, DateTimeRange)) - self.assert_(r1.isempty) - - ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600)) - ts2 = datetime(2000, 12, 31, 23, 59, 59, 999, - tzinfo=FixedOffsetTimezone(600)) - r = DateTimeTZRange(ts1, ts2, '(]') - cur.execute("select %s", (r,)) - r1 = cur.fetchone()[0] - self.assert_(isinstance(r1, DateTimeTZRange)) - self.assertEqual(r1.lower, ts1) - self.assertEqual(r1.upper, ts2) - self.assert_(not r1.lower_inc) - self.assert_(r1.upper_inc) + ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600)) + ts2 = datetime(2000, 12, 31, 23, 59, 59, 999, + tzinfo=FixedOffsetTimezone(600)) + r = DateTimeTZRange(ts1, ts2, '(]') + cur.execute("select %s", (r,)) + r1 = cur.fetchone()[0] + self.assert_(isinstance(r1, DateTimeTZRange)) + self.assertEqual(r1.lower, ts1) + self.assertEqual(r1.upper, ts2) + self.assert_(not r1.lower_inc) + self.assert_(r1.upper_inc) @restore_types def test_register_range_adapter(self): - cur = self.conn.cursor() - cur.execute("create type textrange as range (subtype=text)") - rc = register_range('textrange', 'TextRange', cur) + with self.conn.cursor() as cur: + cur.execute("create type textrange as range (subtype=text)") + rc = register_range('textrange', 'TextRange', cur) - TextRange = rc.range - self.assert_(issubclass(TextRange, Range)) - self.assertEqual(TextRange.__name__, 'TextRange') + TextRange = rc.range + self.assert_(issubclass(TextRange, Range)) + self.assertEqual(TextRange.__name__, 'TextRange') - r = TextRange('a', 'b', '(]') - cur.execute("select %s", (r,)) - r1 = cur.fetchone()[0] - self.assertEqual(r1.lower, 'a') - self.assertEqual(r1.upper, 'b') - self.assert_(not r1.lower_inc) - self.assert_(r1.upper_inc) - - cur.execute("select %s", ([r, r, r],)) - rs = cur.fetchone()[0] - self.assertEqual(len(rs), 3) - for r1 in rs: + r = TextRange('a', 'b', '(]') + cur.execute("select %s", (r,)) + r1 = cur.fetchone()[0] self.assertEqual(r1.lower, 'a') self.assertEqual(r1.upper, 'b') self.assert_(not r1.lower_inc) self.assert_(r1.upper_inc) + cur.execute("select %s", ([r, r, r],)) + rs = cur.fetchone()[0] + self.assertEqual(len(rs), 3) + for r1 in rs: + self.assertEqual(r1.lower, 'a') + self.assertEqual(r1.upper, 'b') + self.assert_(not r1.lower_inc) + self.assert_(r1.upper_inc) + def test_range_escaping(self): - cur = self.conn.cursor() - cur.execute("create type textrange as range (subtype=text)") - rc = register_range('textrange', 'TextRange', cur) + with self.conn.cursor() as cur: + cur.execute("create type textrange as range (subtype=text)") + rc = register_range('textrange', 'TextRange', cur) - TextRange = rc.range - cur.execute(""" - create table rangetest ( - id integer primary key, - range textrange)""") + TextRange = rc.range + cur.execute(""" + create table rangetest ( + id integer primary key, + range textrange)""") - bounds = ['[)', '(]', '()', '[]'] - ranges = [TextRange(low, up, bounds[i % 4]) - for i, (low, up) in enumerate(zip( - [None] + list(map(chr, range(1, 128))), - list(map(chr, range(1, 128))) + [None], - ))] - ranges.append(TextRange()) - ranges.append(TextRange(empty=True)) + bounds = ['[)', '(]', '()', '[]'] + ranges = [TextRange(low, up, bounds[i % 4]) + for i, (low, up) in enumerate(zip( + [None] + list(map(chr, range(1, 128))), + list(map(chr, range(1, 128))) + [None], + ))] + ranges.append(TextRange()) + ranges.append(TextRange(empty=True)) - errs = 0 - for i, r in enumerate(ranges): - # not all the ranges make sense: - # fun fact: select ascii('#') < ascii('$'), '#' < '$' - # yelds... t, f! At least in en_GB.UTF-8 collation. - # which seems suggesting a supremacy of the pound on the dollar. - # So some of these ranges will fail to insert. Be prepared but... - try: - cur.execute(""" - savepoint x; - insert into rangetest (id, range) values (%s, %s); - """, (i, r)) - except psycopg2.DataError: - errs += 1 - cur.execute("rollback to savepoint x;") + errs = 0 + for i, r in enumerate(ranges): + # not all the ranges make sense: + # fun fact: select ascii('#') < ascii('$'), '#' < '$' + # yelds... t, f! At least in en_GB.UTF-8 collation. + # which seems suggesting a supremacy of the pound on the dollar. + # So some of these ranges will fail to insert. Be prepared but... + try: + cur.execute(""" + savepoint x; + insert into rangetest (id, range) values (%s, %s); + """, (i, r)) + except psycopg2.DataError: + errs += 1 + cur.execute("rollback to savepoint x;") - # ...not too many errors! in the above collate there are 17 errors: - # assume in other collates we won't find more than 30 - self.assert_(errs < 30, - "too many collate errors. Is the test working?") + # ...not too many errors! in the above collate there are 17 errors: + # assume in other collates we won't find more than 30 + self.assert_(errs < 30, + "too many collate errors. Is the test working?") - cur.execute("select id, range from rangetest order by id") - for i, r in cur: - self.assertEqual(ranges[i].lower, r.lower) - self.assertEqual(ranges[i].upper, r.upper) - self.assertEqual(ranges[i].lower_inc, r.lower_inc) - self.assertEqual(ranges[i].upper_inc, r.upper_inc) - self.assertEqual(ranges[i].lower_inf, r.lower_inf) - self.assertEqual(ranges[i].upper_inf, r.upper_inf) + cur.execute("select id, range from rangetest order by id") + for i, r in cur: + self.assertEqual(ranges[i].lower, r.lower) + self.assertEqual(ranges[i].upper, r.upper) + self.assertEqual(ranges[i].lower_inc, r.lower_inc) + self.assertEqual(ranges[i].upper_inc, r.upper_inc) + self.assertEqual(ranges[i].lower_inf, r.lower_inf) + self.assertEqual(ranges[i].upper_inf, r.upper_inf) # clear the adapters to allow precise count by scripts/refcounter.py del ext.adapters[TextRange, ext.ISQLQuote] def test_range_not_found(self): - cur = self.conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, - register_range, 'nosuchrange', 'FailRange', cur) + with self.conn.cursor() as cur: + self.assertRaises(psycopg2.ProgrammingError, + register_range, 'nosuchrange', 'FailRange', cur) @restore_types def test_schema_range(self): - cur = self.conn.cursor() - cur.execute("create schema rs") - cur.execute("create type r1 as range (subtype=text)") - cur.execute("create type r2 as range (subtype=text)") - cur.execute("create type rs.r2 as range (subtype=text)") - cur.execute("create type rs.r3 as range (subtype=text)") - cur.execute("savepoint x") + with self.conn.cursor() as cur: + cur.execute("create schema rs") + cur.execute("create type r1 as range (subtype=text)") + cur.execute("create type r2 as range (subtype=text)") + cur.execute("create type rs.r2 as range (subtype=text)") + cur.execute("create type rs.r3 as range (subtype=text)") + cur.execute("savepoint x") - register_range('r1', 'r1', cur) - ra2 = register_range('r2', 'r2', cur) - rars2 = register_range('rs.r2', 'r2', cur) - register_range('rs.r3', 'r3', cur) + register_range('r1', 'r1', cur) + ra2 = register_range('r2', 'r2', cur) + rars2 = register_range('rs.r2', 'r2', cur) + register_range('rs.r3', 'r3', cur) - self.assertNotEqual( - ra2.typecaster.values[0], - rars2.typecaster.values[0]) + self.assertNotEqual( + ra2.typecaster.values[0], + rars2.typecaster.values[0]) - self.assertRaises(psycopg2.ProgrammingError, - register_range, 'r3', 'FailRange', cur) - cur.execute("rollback to savepoint x;") + self.assertRaises(psycopg2.ProgrammingError, + register_range, 'r3', 'FailRange', cur) + cur.execute("rollback to savepoint x;") - self.assertRaises(psycopg2.ProgrammingError, - register_range, 'rs.r1', 'FailRange', cur) - cur.execute("rollback to savepoint x;") + self.assertRaises(psycopg2.ProgrammingError, + register_range, 'rs.r1', 'FailRange', cur) + cur.execute("rollback to savepoint x;") class TestSolveConnCurs(ConnectingTestCase): diff --git a/tests/test_with.py b/tests/test_with.py index b8c043f6..13173452 100755 --- a/tests/test_with.py +++ b/tests/test_with.py @@ -33,15 +33,15 @@ from .testutils import ConnectingTestCase, skip_before_postgres class WithTestCase(ConnectingTestCase): def setUp(self): ConnectingTestCase.setUp(self) - curs = self.conn.cursor() - try: - curs.execute("delete from test_with") - self.conn.commit() - except psycopg2.ProgrammingError: - # assume table doesn't exist - self.conn.rollback() - curs.execute("create table test_with (id integer primary key)") - self.conn.commit() + with self.conn.cursor() as curs: + try: + curs.execute("delete from test_with") + self.conn.commit() + except psycopg2.ProgrammingError: + # assume table doesn't exist + self.conn.rollback() + curs.execute("create table test_with (id integer primary key)") + self.conn.commit() class WithConnectionTestCase(WithTestCase): @@ -49,59 +49,59 @@ class WithConnectionTestCase(WithTestCase): with self.conn as conn: self.assert_(self.conn is conn) self.assertEqual(conn.status, ext.STATUS_READY) - curs = conn.cursor() - curs.execute("insert into test_with values (1)") + with conn.cursor() as curs: + curs.execute("insert into test_with values (1)") self.assertEqual(conn.status, ext.STATUS_BEGIN) self.assertEqual(self.conn.status, ext.STATUS_READY) self.assert_(not self.conn.closed) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), [(1,)]) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), [(1,)]) def test_with_connect_idiom(self): with self.connect() as conn: self.assertEqual(conn.status, ext.STATUS_READY) - curs = conn.cursor() - curs.execute("insert into test_with values (2)") - self.assertEqual(conn.status, ext.STATUS_BEGIN) + with conn.cursor() as curs: + curs.execute("insert into test_with values (2)") + self.assertEqual(conn.status, ext.STATUS_BEGIN) self.assertEqual(self.conn.status, ext.STATUS_READY) self.assert_(not self.conn.closed) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), [(2,)]) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), [(2,)]) def test_with_error_db(self): def f(): with self.conn as conn: - curs = conn.cursor() - curs.execute("insert into test_with values ('a')") + with conn.cursor() as curs: + curs.execute("insert into test_with values ('a')") self.assertRaises(psycopg2.DataError, f) self.assertEqual(self.conn.status, ext.STATUS_READY) self.assert_(not self.conn.closed) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), []) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), []) def test_with_error_python(self): def f(): with self.conn as conn: - curs = conn.cursor() - curs.execute("insert into test_with values (3)") - 1 / 0 + with conn.cursor() as curs: + curs.execute("insert into test_with values (3)") + 1 / 0 self.assertRaises(ZeroDivisionError, f) self.assertEqual(self.conn.status, ext.STATUS_READY) self.assert_(not self.conn.closed) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), []) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), []) def test_with_closed(self): def f(): @@ -120,15 +120,15 @@ class WithConnectionTestCase(WithTestCase): super(MyConn, self).commit() with self.connect(connection_factory=MyConn) as conn: - curs = conn.cursor() - curs.execute("insert into test_with values (10)") + with conn.cursor() as curs: + curs.execute("insert into test_with values (10)") self.assertEqual(conn.status, ext.STATUS_READY) self.assert_(commits) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), [(10,)]) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), [(10,)]) def test_subclass_rollback(self): rollbacks = [] @@ -140,9 +140,9 @@ class WithConnectionTestCase(WithTestCase): try: with self.connect(connection_factory=MyConn) as conn: - curs = conn.cursor() - curs.execute("insert into test_with values (11)") - 1 / 0 + with conn.cursor() as curs: + curs.execute("insert into test_with values (11)") + 1 / 0 except ZeroDivisionError: pass else: @@ -151,9 +151,9 @@ class WithConnectionTestCase(WithTestCase): self.assertEqual(conn.status, ext.STATUS_READY) self.assert_(rollbacks) - curs = conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), []) + with conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), []) class WithCursorTestCase(WithTestCase): @@ -168,9 +168,9 @@ class WithCursorTestCase(WithTestCase): self.assertEqual(self.conn.status, ext.STATUS_READY) self.assert_(not self.conn.closed) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), [(4,)]) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), [(4,)]) def test_with_error(self): try: @@ -185,9 +185,9 @@ class WithCursorTestCase(WithTestCase): self.assert_(not self.conn.closed) self.assert_(curs.closed) - curs = self.conn.cursor() - curs.execute("select * from test_with") - self.assertEqual(curs.fetchall(), []) + with self.conn.cursor() as curs: + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), []) def test_subclass(self): closes = [] diff --git a/tests/testutils.py b/tests/testutils.py index cd69c468..b0770b34 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -228,9 +228,9 @@ def skip_if_no_uuid(f): @wraps(f) def skip_if_no_uuid_(self): try: - cur = self.conn.cursor() - cur.execute("select typname from pg_type where typname = 'uuid'") - has = cur.fetchone() + with self.conn.cursor() as cur: + cur.execute("select typname from pg_type where typname = 'uuid'") + has = cur.fetchone() finally: self.conn.rollback() @@ -249,14 +249,14 @@ def skip_if_tpc_disabled(f): def skip_if_tpc_disabled_(self): cnn = self.connect() try: - cur = cnn.cursor() - try: - cur.execute("SHOW max_prepared_transactions;") - except psycopg2.ProgrammingError: - return self.skipTest( - "server too old: two phase transactions not supported.") - else: - mtp = int(cur.fetchone()[0]) + with cnn.cursor() as cur: + try: + cur.execute("SHOW max_prepared_transactions;") + except psycopg2.ProgrammingError: + return self.skipTest( + "server too old: two phase transactions not supported.") + else: + mtp = int(cur.fetchone()[0]) finally: cnn.close()