Always close cursor objects in tests

This commit is contained in:
Jon Dufresne 2020-02-02 16:20:52 -08:00
parent 6b63fae20a
commit de858b9cb2
22 changed files with 2954 additions and 2953 deletions

View File

@ -61,7 +61,7 @@ class AsyncTests(ConnectingTestCase):
self.wait(self.conn) self.wait(self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(''' curs.execute('''
CREATE TEMPORARY TABLE table1 ( CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY id int PRIMARY KEY
@ -71,6 +71,8 @@ class AsyncTests(ConnectingTestCase):
def test_connection_setup(self): def test_connection_setup(self):
cur = self.conn.cursor() cur = self.conn.cursor()
sync_cur = self.sync_conn.cursor() sync_cur = self.sync_conn.cursor()
cur.close()
sync_cur.close()
del cur, sync_cur del cur, sync_cur
self.assert_(self.conn.async_) self.assert_(self.conn.async_)
@ -90,7 +92,7 @@ class AsyncTests(ConnectingTestCase):
self.conn.cursor, "name") self.conn.cursor, "name")
def test_async_select(self): def test_async_select(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertFalse(self.conn.isexecuting()) self.assertFalse(self.conn.isexecuting())
cur.execute("select 'a'") cur.execute("select 'a'")
self.assertTrue(self.conn.isexecuting()) self.assertTrue(self.conn.isexecuting())
@ -103,7 +105,7 @@ class AsyncTests(ConnectingTestCase):
@slow @slow
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_async_callproc(self): def test_async_callproc(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.callproc("pg_sleep", (0.1, )) cur.callproc("pg_sleep", (0.1, ))
self.assertTrue(self.conn.isexecuting()) self.assertTrue(self.conn.isexecuting())
@ -113,8 +115,9 @@ class AsyncTests(ConnectingTestCase):
@slow @slow
def test_async_after_async(self): def test_async_after_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur2 = self.conn.cursor() cur2 = self.conn.cursor()
cur2.close()
del cur2 del cur2
cur.execute("insert into table1 values (1)") cur.execute("insert into table1 values (1)")
@ -141,7 +144,7 @@ class AsyncTests(ConnectingTestCase):
self.assertEquals(cur.fetchone(), None) self.assertEquals(cur.fetchone(), None)
def test_fetch_after_async(self): def test_fetch_after_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 'a'") cur.execute("select 'a'")
# a fetch after an asynchronous query should raise an error # a fetch after an asynchronous query should raise an error
@ -152,16 +155,14 @@ class AsyncTests(ConnectingTestCase):
self.assertEquals(cur.fetchall()[0][0], "a") self.assertEquals(cur.fetchall()[0][0], "a")
def test_rollback_while_async(self): def test_rollback_while_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 'a'") cur.execute("select 'a'")
# a rollback should not work in asynchronous mode # a rollback should not work in asynchronous mode
self.assertRaises(psycopg2.ProgrammingError, self.conn.rollback) self.assertRaises(psycopg2.ProgrammingError, self.conn.rollback)
def test_commit_while_async(self): def test_commit_while_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("begin") cur.execute("begin")
self.wait(cur) self.wait(cur)
@ -188,8 +189,7 @@ class AsyncTests(ConnectingTestCase):
self.assertEquals(cur.fetchone(), None) self.assertEquals(cur.fetchone(), None)
def test_set_parameters_while_async(self): def test_set_parameters_while_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 'c'") cur.execute("select 'c'")
self.assertTrue(self.conn.isexecuting()) self.assertTrue(self.conn.isexecuting())
@ -207,7 +207,7 @@ class AsyncTests(ConnectingTestCase):
self.conn.set_isolation_level, 1) self.conn.set_isolation_level, 1)
def test_reset_while_async(self): def test_reset_while_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 'c'") cur.execute("select 'c'")
self.assertTrue(self.conn.isexecuting()) self.assertTrue(self.conn.isexecuting())
@ -215,8 +215,7 @@ class AsyncTests(ConnectingTestCase):
self.assertRaises(psycopg2.ProgrammingError, self.conn.reset) self.assertRaises(psycopg2.ProgrammingError, self.conn.reset)
def test_async_iter(self): def test_async_iter(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("begin") cur.execute("begin")
self.wait(cur) self.wait(cur)
cur.execute(""" cur.execute("""
@ -236,7 +235,7 @@ class AsyncTests(ConnectingTestCase):
self.assertFalse(self.conn.isexecuting()) self.assertFalse(self.conn.isexecuting())
def test_copy_while_async(self): def test_copy_while_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 'a'") cur.execute("select 'a'")
# copy should fail # copy should fail
@ -250,13 +249,13 @@ class AsyncTests(ConnectingTestCase):
self.conn.lobject) self.conn.lobject)
def test_async_executemany(self): def test_async_executemany(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises( self.assertRaises(
psycopg2.ProgrammingError, psycopg2.ProgrammingError,
cur.executemany, "insert into table1 values (%s)", [1, 2, 3]) cur.executemany, "insert into table1 values (%s)", [1, 2, 3])
def test_async_scroll(self): def test_async_scroll(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(""" cur.execute("""
insert into table1 values (1); insert into table1 values (1);
insert into table1 values (2); insert into table1 values (2);
@ -274,16 +273,16 @@ class AsyncTests(ConnectingTestCase):
cur.scroll(1) cur.scroll(1)
self.assertEquals(cur.fetchall(), [(2, ), (3, )]) self.assertEquals(cur.fetchall(), [(2, ), (3, )])
cur = self.conn.cursor() with self.conn.cursor() as cur2:
cur.execute("select id from table1 order by id") cur.execute("select id from table1 order by id")
self.wait(cur) self.wait(cur)
cur2 = self.conn.cursor() with self.conn.cursor() as cur2:
self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1) self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1)
self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4) self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select id from table1 order by id") cur.execute("select id from table1 order by id")
self.wait(cur) self.wait(cur)
cur.scroll(2) cur.scroll(2)
@ -291,7 +290,7 @@ class AsyncTests(ConnectingTestCase):
self.assertEquals(cur.fetchall(), [(2, ), (3, )]) self.assertEquals(cur.fetchall(), [(2, ), (3, )])
def test_scroll(self): def test_scroll(self):
cur = self.sync_conn.cursor() with self.sync_conn.cursor() as cur:
cur.execute("create table table1 (id int)") cur.execute("create table table1 (id int)")
cur.execute(""" cur.execute("""
insert into table1 values (1); insert into table1 values (1);
@ -304,7 +303,7 @@ class AsyncTests(ConnectingTestCase):
self.assertEquals(cur.fetchall(), [(2, ), (3, )]) self.assertEquals(cur.fetchall(), [(2, ), (3, )])
def test_async_dont_read_all(self): def test_async_dont_read_all(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select repeat('a', 10000); select repeat('b', 10000)") cur.execute("select repeat('a', 10000); select repeat('b', 10000)")
# fetch the result # fetch the result
@ -326,7 +325,7 @@ class AsyncTests(ConnectingTestCase):
@slow @slow
def test_flush_on_write(self): def test_flush_on_write(self):
# a very large query requires a flush loop to be sent to the backend # a very large query requires a flush loop to be sent to the backend
curs = self.conn.cursor() with self.conn.cursor() as curs:
for mb in 1, 5, 10, 20, 50: for mb in 1, 5, 10, 20, 50:
size = mb * 1024 * 1024 size = mb * 1024 * 1024
stub = PollableStub(self.conn) stub = PollableStub(self.conn)
@ -343,7 +342,7 @@ class AsyncTests(ConnectingTestCase):
warnings.warn("sending a large query didn't trigger block on write.") warnings.warn("sending a large query didn't trigger block on write.")
def test_sync_poll(self): def test_sync_poll(self):
cur = self.sync_conn.cursor() with self.sync_conn.cursor() as cur:
cur.execute("select 1") cur.execute("select 1")
# polling with a sync query works # polling with a sync query works
cur.connection.poll() cur.connection.poll()
@ -351,9 +350,7 @@ class AsyncTests(ConnectingTestCase):
@slow @slow
def test_notify(self): def test_notify(self):
cur = self.conn.cursor() with self.conn.cursor() as cur, self.sync_conn.cursor() as sync_cur:
sync_cur = self.sync_conn.cursor()
sync_cur.execute("listen test_notify") sync_cur.execute("listen test_notify")
self.sync_conn.commit() self.sync_conn.commit()
cur.execute("notify test_notify") cur.execute("notify test_notify")
@ -374,8 +371,7 @@ class AsyncTests(ConnectingTestCase):
self.fail("No NOTIFY in 2.5 seconds") self.fail("No NOTIFY in 2.5 seconds")
def test_async_fetch_wrong_cursor(self): def test_async_fetch_wrong_cursor(self):
cur1 = self.conn.cursor() with self.conn.cursor() as cur1, self.conn.cursor() as cur2:
cur2 = self.conn.cursor()
cur1.execute("select 1") cur1.execute("select 1")
self.wait(cur1) self.wait(cur1)
@ -386,7 +382,7 @@ class AsyncTests(ConnectingTestCase):
self.assertEquals(cur1.fetchone()[0], 1) self.assertEquals(cur1.fetchone()[0], 1)
def test_error(self): def test_error(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("insert into table1 values (%s)", (1, )) cur.execute("insert into table1 values (%s)", (1, ))
self.wait(cur) self.wait(cur)
cur.execute("insert into table1 values (%s)", (1, )) cur.execute("insert into table1 values (%s)", (1, ))
@ -409,7 +405,7 @@ class AsyncTests(ConnectingTestCase):
self.wait(cur) self.wait(cur)
def test_stop_on_first_error(self): def test_stop_on_first_error(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 1; select x; select 1/0; select 2") cur.execute("select 1; select x; select 1/0; select 2")
self.assertRaises(psycopg2.errors.UndefinedColumn, self.wait, cur) self.assertRaises(psycopg2.errors.UndefinedColumn, self.wait, cur)
@ -418,8 +414,7 @@ class AsyncTests(ConnectingTestCase):
self.assertEqual(cur.fetchone(), (1,)) self.assertEqual(cur.fetchone(), (1,))
def test_error_two_cursors(self): def test_error_two_cursors(self):
cur = self.conn.cursor() with self.conn.cursor() as cur, self.conn.cursor() as cur2:
cur2 = self.conn.cursor()
cur.execute("select * from no_such_table") cur.execute("select * from no_such_table")
self.assertRaises(psycopg2.ProgrammingError, self.wait, cur) self.assertRaises(psycopg2.ProgrammingError, self.wait, cur)
cur2.execute("select 1") cur2.execute("select 1")
@ -428,7 +423,7 @@ class AsyncTests(ConnectingTestCase):
def test_notices(self): def test_notices(self):
del self.conn.notices[:] del self.conn.notices[:]
cur = self.conn.cursor() with self.conn.cursor() as cur:
if self.conn.info.server_version >= 90300: if self.conn.info.server_version >= 90300:
cur.execute("set client_min_messages=debug1") cur.execute("set client_min_messages=debug1")
self.wait(cur) self.wait(cur)
@ -438,14 +433,14 @@ class AsyncTests(ConnectingTestCase):
self.assert_(self.conn.notices) self.assert_(self.conn.notices)
def test_async_cursor_gone(self): def test_async_cursor_gone(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 42;") cur.execute("select 42;")
del cur del cur
gc.collect() gc.collect()
self.assertRaises(psycopg2.InterfaceError, self.wait, self.conn) self.assertRaises(psycopg2.InterfaceError, self.wait, self.conn)
# The connection is still usable # The connection is still usable
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 42;") cur.execute("select 42;")
self.wait(self.conn) self.wait(self.conn)
self.assertEqual(cur.fetchone(), (42,)) self.assertEqual(cur.fetchone(), (42,))
@ -464,7 +459,7 @@ class AsyncTests(ConnectingTestCase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_copy_no_hang(self): def test_copy_no_hang(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("copy (select 1) to stdout") cur.execute("copy (select 1) to stdout")
self.assertRaises(psycopg2.ProgrammingError, self.wait, self.conn) self.assertRaises(psycopg2.ProgrammingError, self.wait, self.conn)
@ -473,7 +468,7 @@ class AsyncTests(ConnectingTestCase):
def test_non_block_after_notification(self): def test_non_block_after_notification(self):
from select import select from select import select
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(""" cur.execute("""
select 1; select 1;
do $$ do $$

View File

@ -47,7 +47,7 @@ class AsyncTests(ConnectingTestCase):
self.wait(self.conn) self.wait(self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(''' curs.execute('''
CREATE TEMPORARY TABLE table1 ( CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY id int PRIMARY KEY
@ -57,6 +57,8 @@ class AsyncTests(ConnectingTestCase):
def test_connection_setup(self): def test_connection_setup(self):
cur = self.conn.cursor() cur = self.conn.cursor()
sync_cur = self.sync_conn.cursor() sync_cur = self.sync_conn.cursor()
cur.close()
sync_cur.close()
del cur, sync_cur del cur, sync_cur
self.assert_(self.conn.async) self.assert_(self.conn.async)
@ -97,7 +99,7 @@ class CancelTests(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(''' cur.execute('''
CREATE TEMPORARY TABLE table1 ( CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY id int PRIMARY KEY
@ -110,7 +112,7 @@ class CancelTests(ConnectingTestCase):
async_conn = psycopg2.connect(dsn, async=True) async_conn = psycopg2.connect(dsn, async=True)
self.assertRaises(psycopg2.OperationalError, async_conn.cancel) self.assertRaises(psycopg2.OperationalError, async_conn.cancel)
extras.wait_select(async_conn) extras.wait_select(async_conn)
cur = async_conn.cursor() with async_conn.cursor() as cur:
cur.execute("select pg_sleep(10)") cur.execute("select pg_sleep(10)")
time.sleep(1) time.sleep(1)
self.assertTrue(async_conn.isexecuting()) self.assertTrue(async_conn.isexecuting())
@ -183,8 +185,7 @@ class AsyncReplicationTest(ReplicationTestCase):
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.create_replication_slot(cur, output_plugin='test_decoding') self.create_replication_slot(cur, output_plugin='test_decoding')
self.wait(cur) self.wait(cur)

View File

@ -39,7 +39,7 @@ class StolenReferenceTestCase(ConnectingTestCase):
return 42 return 42
UUID = psycopg2.extensions.new_type((2950,), "UUID", fish) UUID = psycopg2.extensions.new_type((2950,), "UUID", fish)
psycopg2.extensions.register_type(UUID, self.conn) psycopg2.extensions.register_type(UUID, self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid") curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid")
curs.fetchone() curs.fetchone()

View File

@ -41,7 +41,7 @@ class CancelTests(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(''' cur.execute('''
CREATE TEMPORARY TABLE table1 ( CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY id int PRIMARY KEY
@ -57,7 +57,7 @@ class CancelTests(ConnectingTestCase):
errors = [] errors = []
def neverending(conn): def neverending(conn):
cur = conn.cursor() with conn.cursor() as cur:
try: try:
self.assertRaises(psycopg2.extensions.QueryCanceledError, self.assertRaises(psycopg2.extensions.QueryCanceledError,
cur.execute, "select pg_sleep(60)") cur.execute, "select pg_sleep(60)")
@ -70,7 +70,7 @@ class CancelTests(ConnectingTestCase):
raise raise
def canceller(conn): def canceller(conn):
cur = conn.cursor() with conn.cursor() as cur:
try: try:
conn.cancel() conn.cancel()
except Exception as e: except Exception as e:
@ -95,7 +95,7 @@ class CancelTests(ConnectingTestCase):
async_conn = psycopg2.connect(dsn, async_=True) async_conn = psycopg2.connect(dsn, async_=True)
self.assertRaises(psycopg2.OperationalError, async_conn.cancel) self.assertRaises(psycopg2.OperationalError, async_conn.cancel)
extras.wait_select(async_conn) extras.wait_select(async_conn)
cur = async_conn.cursor() with async_conn.cursor() as cur:
cur.execute("select pg_sleep(10)") cur.execute("select pg_sleep(10)")
time.sleep(1) time.sleep(1)
self.assertTrue(async_conn.isexecuting()) self.assertTrue(async_conn.isexecuting())

View File

@ -80,7 +80,7 @@ class ConnectionTests(ConnectingTestCase):
def test_cleanup_on_badconn_close(self): def test_cleanup_on_badconn_close(self):
# ticket #148 # ticket #148
conn = self.conn conn = self.conn
cur = conn.cursor() with conn.cursor() as cur:
self.assertRaises(psycopg2.OperationalError, self.assertRaises(psycopg2.OperationalError,
cur.execute, "select pg_terminate_backend(pg_backend_pid())") cur.execute, "select pg_terminate_backend(pg_backend_pid())")
@ -113,7 +113,7 @@ class ConnectionTests(ConnectingTestCase):
def test_notices(self): def test_notices(self):
conn = self.conn conn = self.conn
cur = conn.cursor() with conn.cursor() as cur:
if self.conn.info.server_version >= 90300: if self.conn.info.server_version >= 90300:
cur.execute("set client_min_messages=debug1") cur.execute("set client_min_messages=debug1")
cur.execute("create temp table chatty (id serial primary key);") cur.execute("create temp table chatty (id serial primary key);")
@ -122,7 +122,7 @@ class ConnectionTests(ConnectingTestCase):
def test_notices_consistent_order(self): def test_notices_consistent_order(self):
conn = self.conn conn = self.conn
cur = conn.cursor() with conn.cursor() as cur:
if self.conn.info.server_version >= 90300: if self.conn.info.server_version >= 90300:
cur.execute("set client_min_messages=debug1") cur.execute("set client_min_messages=debug1")
cur.execute(""" cur.execute("""
@ -142,7 +142,7 @@ class ConnectionTests(ConnectingTestCase):
@slow @slow
def test_notices_limited(self): def test_notices_limited(self):
conn = self.conn conn = self.conn
cur = conn.cursor() with conn.cursor() as cur:
if self.conn.info.server_version >= 90300: if self.conn.info.server_version >= 90300:
cur.execute("set client_min_messages=debug1") cur.execute("set client_min_messages=debug1")
for i in range(0, 100, 10): for i in range(0, 100, 10):
@ -157,7 +157,7 @@ class ConnectionTests(ConnectingTestCase):
def test_notices_deque(self): def test_notices_deque(self):
conn = self.conn conn = self.conn
self.conn.notices = deque() self.conn.notices = deque()
cur = conn.cursor() with conn.cursor() as cur:
if self.conn.info.server_version >= 90300: if self.conn.info.server_version >= 90300:
cur.execute("set client_min_messages=debug1") cur.execute("set client_min_messages=debug1")
@ -187,7 +187,7 @@ class ConnectionTests(ConnectingTestCase):
def test_notices_noappend(self): def test_notices_noappend(self):
conn = self.conn conn = self.conn
self.conn.notices = None # will make an error swallowes ok self.conn.notices = None # will make an error swallowes ok
cur = conn.cursor() with conn.cursor() as cur:
if self.conn.info.server_version >= 90300: if self.conn.info.server_version >= 90300:
cur.execute("set client_min_messages=debug1") cur.execute("set client_min_messages=debug1")
@ -233,7 +233,7 @@ class ConnectionTests(ConnectingTestCase):
def test_encoding_name(self): def test_encoding_name(self):
self.conn.set_client_encoding("EUC_JP") self.conn.set_client_encoding("EUC_JP")
# conn.encoding is 'EUCJP' now. # conn.encoding is 'EUCJP' now.
cur = self.conn.cursor() with self.conn.cursor() as cur:
ext.register_type(ext.UNICODE, cur) ext.register_type(ext.UNICODE, cur)
cur.execute("select 'foo'::text;") cur.execute("select 'foo'::text;")
self.assertEqual(cur.fetchone()[0], u'foo') self.assertEqual(cur.fetchone()[0], u'foo')
@ -281,7 +281,7 @@ class ConnectionTests(ConnectingTestCase):
while conn.notices: while conn.notices:
notices.append((2, conn.notices.pop())) notices.append((2, conn.notices.pop()))
cur = conn.cursor() with conn.cursor() as cur:
t1 = threading.Thread(target=committer) t1 = threading.Thread(target=committer)
t1.start() t1.start()
for i in range(1000): for i in range(1000):
@ -297,36 +297,36 @@ class ConnectionTests(ConnectingTestCase):
def test_connect_cursor_factory(self): def test_connect_cursor_factory(self):
conn = self.connect(cursor_factory=psycopg2.extras.DictCursor) conn = self.connect(cursor_factory=psycopg2.extras.DictCursor)
cur = conn.cursor() with conn.cursor() as cur:
cur.execute("select 1 as a") cur.execute("select 1 as a")
self.assertEqual(cur.fetchone()['a'], 1) self.assertEqual(cur.fetchone()['a'], 1)
def test_cursor_factory(self): def test_cursor_factory(self):
self.assertEqual(self.conn.cursor_factory, None) self.assertEqual(self.conn.cursor_factory, None)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 1 as a") cur.execute("select 1 as a")
self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone()) self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone())
self.conn.cursor_factory = psycopg2.extras.DictCursor self.conn.cursor_factory = psycopg2.extras.DictCursor
self.assertEqual(self.conn.cursor_factory, psycopg2.extras.DictCursor) self.assertEqual(self.conn.cursor_factory, psycopg2.extras.DictCursor)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 1 as a") cur.execute("select 1 as a")
self.assertEqual(cur.fetchone()['a'], 1) self.assertEqual(cur.fetchone()['a'], 1)
self.conn.cursor_factory = None self.conn.cursor_factory = None
self.assertEqual(self.conn.cursor_factory, None) self.assertEqual(self.conn.cursor_factory, None)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 1 as a") cur.execute("select 1 as a")
self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone()) self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone())
def test_cursor_factory_none(self): def test_cursor_factory_none(self):
# issue #210 # issue #210
conn = self.connect() conn = self.connect()
cur = conn.cursor(cursor_factory=None) with conn.cursor(cursor_factory=None) as cur:
self.assertEqual(type(cur), ext.cursor) self.assertEqual(type(cur), ext.cursor)
conn = self.connect(cursor_factory=psycopg2.extras.DictCursor) conn = self.connect(cursor_factory=psycopg2.extras.DictCursor)
cur = conn.cursor(cursor_factory=None) with conn.cursor(cursor_factory=None) as cur:
self.assertEqual(type(cur), psycopg2.extras.DictCursor) self.assertEqual(type(cur), psycopg2.extras.DictCursor)
def test_failed_init_status(self): def test_failed_init_status(self):
@ -583,8 +583,7 @@ class IsolationLevelsTestCase(ConnectingTestCase):
def test_set_isolation_level(self): def test_set_isolation_level(self):
conn = self.connect() conn = self.connect()
curs = conn.cursor() with conn.cursor() as curs:
levels = [ levels = [
('read uncommitted', ('read uncommitted',
ext.ISOLATION_LEVEL_READ_UNCOMMITTED), ext.ISOLATION_LEVEL_READ_UNCOMMITTED),
@ -615,8 +614,7 @@ class IsolationLevelsTestCase(ConnectingTestCase):
def test_set_isolation_level_autocommit(self): def test_set_isolation_level_autocommit(self):
conn = self.connect() conn = self.connect()
curs = conn.cursor() with conn.cursor() as curs:
conn.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT) conn.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT)
self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT) self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT)
self.assert_(conn.autocommit) self.assert_(conn.autocommit)
@ -630,8 +628,7 @@ class IsolationLevelsTestCase(ConnectingTestCase):
def test_set_isolation_level_default(self): def test_set_isolation_level_default(self):
conn = self.connect() conn = self.connect()
curs = conn.cursor() with conn.cursor() as curs:
conn.autocommit = True conn.autocommit = True
curs.execute("set default_transaction_isolation to 'read committed'") curs.execute("set default_transaction_isolation to 'read committed'")
@ -649,8 +646,7 @@ class IsolationLevelsTestCase(ConnectingTestCase):
def test_set_isolation_level_abort(self): def test_set_isolation_level_abort(self):
conn = self.connect() conn = self.connect()
cur = conn.cursor() with conn.cursor() as cur:
self.assertEqual(ext.TRANSACTION_STATUS_IDLE, self.assertEqual(ext.TRANSACTION_STATUS_IDLE,
conn.info.transaction_status) conn.info.transaction_status)
cur.execute("insert into isolevel values (10);") cur.execute("insert into isolevel values (10);")
@ -691,12 +687,12 @@ class IsolationLevelsTestCase(ConnectingTestCase):
cnn2 = self.connect() cnn2 = self.connect()
cnn2.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT) cnn2.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT)
cur1 = cnn1.cursor() with cnn1.cursor() as cur1:
cur1.execute("select count(*) from isolevel;") cur1.execute("select count(*) from isolevel;")
self.assertEqual(0, cur1.fetchone()[0]) self.assertEqual(0, cur1.fetchone()[0])
cnn1.commit() cnn1.commit()
cur2 = cnn2.cursor() with cnn2.cursor() as cur2:
cur2.execute("insert into isolevel values (10);") cur2.execute("insert into isolevel values (10);")
cur1.execute("select count(*) from isolevel;") cur1.execute("select count(*) from isolevel;")
@ -707,12 +703,12 @@ class IsolationLevelsTestCase(ConnectingTestCase):
cnn2 = self.connect() cnn2 = self.connect()
cnn2.set_isolation_level(ext.ISOLATION_LEVEL_READ_COMMITTED) cnn2.set_isolation_level(ext.ISOLATION_LEVEL_READ_COMMITTED)
cur1 = cnn1.cursor() with cnn1.cursor() as cur1:
cur1.execute("select count(*) from isolevel;") cur1.execute("select count(*) from isolevel;")
self.assertEqual(0, cur1.fetchone()[0]) self.assertEqual(0, cur1.fetchone()[0])
cnn1.commit() cnn1.commit()
cur2 = cnn2.cursor() with cnn2.cursor() as cur2:
cur2.execute("insert into isolevel values (10);") cur2.execute("insert into isolevel values (10);")
cur1.execute("insert into isolevel values (20);") cur1.execute("insert into isolevel values (20);")
@ -733,12 +729,12 @@ class IsolationLevelsTestCase(ConnectingTestCase):
cnn2 = self.connect() cnn2 = self.connect()
cnn2.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE) cnn2.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE)
cur1 = cnn1.cursor() with cnn1.cursor() as cur1:
cur1.execute("select count(*) from isolevel;") cur1.execute("select count(*) from isolevel;")
self.assertEqual(0, cur1.fetchone()[0]) self.assertEqual(0, cur1.fetchone()[0])
cnn1.commit() cnn1.commit()
cur2 = cnn2.cursor() with cnn2.cursor() as cur2:
cur2.execute("insert into isolevel values (10);") cur2.execute("insert into isolevel values (10);")
cur1.execute("insert into isolevel values (20);") cur1.execute("insert into isolevel values (20);")
@ -766,7 +762,7 @@ class IsolationLevelsTestCase(ConnectingTestCase):
cnn.set_isolation_level, 1) cnn.set_isolation_level, 1)
def test_setattr_isolation_level_int(self): def test_setattr_isolation_level_int(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.conn.isolation_level = ext.ISOLATION_LEVEL_SERIALIZABLE self.conn.isolation_level = ext.ISOLATION_LEVEL_SERIALIZABLE
self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE)
@ -814,7 +810,7 @@ class IsolationLevelsTestCase(ConnectingTestCase):
self.assertEqual(cur.fetchone()[0], isol) self.assertEqual(cur.fetchone()[0], isol)
def test_setattr_isolation_level_str(self): def test_setattr_isolation_level_str(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.conn.isolation_level = "serializable" self.conn.isolation_level = "serializable"
self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE) self.assertEqual(self.conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE)
@ -946,7 +942,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase):
cnn.tpc_begin(xid) cnn.tpc_begin(xid)
self.assertEqual(cnn.status, ext.STATUS_BEGIN) self.assertEqual(cnn.status, ext.STATUS_BEGIN)
cur = cnn.cursor() with cnn.cursor() as cur:
cur.execute("insert into test_tpc values ('test_tpc_commit');") cur.execute("insert into test_tpc values ('test_tpc_commit');")
self.assertEqual(0, self.count_xacts()) self.assertEqual(0, self.count_xacts())
self.assertEqual(0, self.count_test_records()) self.assertEqual(0, self.count_test_records())
@ -969,7 +965,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase):
cnn.tpc_begin(xid) cnn.tpc_begin(xid)
self.assertEqual(cnn.status, ext.STATUS_BEGIN) self.assertEqual(cnn.status, ext.STATUS_BEGIN)
cur = cnn.cursor() with cnn.cursor() as cur:
cur.execute("insert into test_tpc values ('test_tpc_commit_1p');") cur.execute("insert into test_tpc values ('test_tpc_commit_1p');")
self.assertEqual(0, self.count_xacts()) self.assertEqual(0, self.count_xacts())
self.assertEqual(0, self.count_test_records()) self.assertEqual(0, self.count_test_records())
@ -1013,7 +1009,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase):
cnn.tpc_begin(xid) cnn.tpc_begin(xid)
self.assertEqual(cnn.status, ext.STATUS_BEGIN) self.assertEqual(cnn.status, ext.STATUS_BEGIN)
cur = cnn.cursor() with cnn.cursor() as cur:
cur.execute("insert into test_tpc values ('test_tpc_rollback');") cur.execute("insert into test_tpc values ('test_tpc_rollback');")
self.assertEqual(0, self.count_xacts()) self.assertEqual(0, self.count_xacts())
self.assertEqual(0, self.count_test_records()) self.assertEqual(0, self.count_test_records())
@ -1036,7 +1032,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase):
cnn.tpc_begin(xid) cnn.tpc_begin(xid)
self.assertEqual(cnn.status, ext.STATUS_BEGIN) self.assertEqual(cnn.status, ext.STATUS_BEGIN)
cur = cnn.cursor() with cnn.cursor() as cur:
cur.execute("insert into test_tpc values ('test_tpc_rollback_1p');") cur.execute("insert into test_tpc values ('test_tpc_rollback_1p');")
self.assertEqual(0, self.count_xacts()) self.assertEqual(0, self.count_xacts())
self.assertEqual(0, self.count_test_records()) self.assertEqual(0, self.count_test_records())
@ -1054,7 +1050,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase):
cnn.tpc_begin(xid) cnn.tpc_begin(xid)
self.assertEqual(cnn.status, ext.STATUS_BEGIN) self.assertEqual(cnn.status, ext.STATUS_BEGIN)
cur = cnn.cursor() with cnn.cursor() as cur:
cur.execute("insert into test_tpc values ('test_tpc_commit_rec');") cur.execute("insert into test_tpc values ('test_tpc_commit_rec');")
self.assertEqual(0, self.count_xacts()) self.assertEqual(0, self.count_xacts())
self.assertEqual(0, self.count_test_records()) self.assertEqual(0, self.count_test_records())
@ -1078,7 +1074,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase):
cnn.tpc_recover() cnn.tpc_recover()
self.assertEqual(ext.STATUS_READY, cnn.status) self.assertEqual(ext.STATUS_READY, cnn.status)
cur = cnn.cursor() with cnn.cursor() as cur:
cur.execute("select 1") cur.execute("select 1")
self.assertEqual(ext.STATUS_BEGIN, cnn.status) self.assertEqual(ext.STATUS_BEGIN, cnn.status)
cnn.tpc_recover() cnn.tpc_recover()
@ -1088,7 +1084,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase):
# insert a few test xns # insert a few test xns
cnn = self.connect() cnn = self.connect()
cnn.set_isolation_level(0) cnn.set_isolation_level(0)
cur = cnn.cursor() with cnn.cursor() as cur:
cur.execute("begin; prepare transaction '1-foo';") cur.execute("begin; prepare transaction '1-foo';")
cur.execute("begin; prepare transaction '2-bar';") cur.execute("begin; prepare transaction '2-bar';")
@ -1121,7 +1117,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase):
cnn.tpc_prepare() cnn.tpc_prepare()
cnn = self.connect() cnn = self.connect()
cur = cnn.cursor() with cnn.cursor() as cur:
cur.execute("select gid from pg_prepared_xacts where database = %s;", cur.execute("select gid from pg_prepared_xacts where database = %s;",
(dbname,)) (dbname,))
self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0]) self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0])
@ -1248,14 +1244,14 @@ class TransactionControlTests(ConnectingTestCase):
ext.ISOLATION_LEVEL_SERIALIZABLE) ext.ISOLATION_LEVEL_SERIALIZABLE)
def test_not_in_transaction(self): def test_not_in_transaction(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 1") cur.execute("select 1")
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
self.conn.set_session, self.conn.set_session,
ext.ISOLATION_LEVEL_SERIALIZABLE) ext.ISOLATION_LEVEL_SERIALIZABLE)
def test_set_isolation_level(self): def test_set_isolation_level(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.conn.set_session( self.conn.set_session(
ext.ISOLATION_LEVEL_SERIALIZABLE) ext.ISOLATION_LEVEL_SERIALIZABLE)
cur.execute("SHOW transaction_isolation;") cur.execute("SHOW transaction_isolation;")
@ -1287,7 +1283,7 @@ class TransactionControlTests(ConnectingTestCase):
self.conn.rollback() self.conn.rollback()
def test_set_isolation_level_str(self): def test_set_isolation_level_str(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.conn.set_session("serializable") self.conn.set_session("serializable")
cur.execute("SHOW transaction_isolation;") cur.execute("SHOW transaction_isolation;")
self.assertEqual(cur.fetchone()[0], 'serializable') self.assertEqual(cur.fetchone()[0], 'serializable')
@ -1322,7 +1318,7 @@ class TransactionControlTests(ConnectingTestCase):
def test_set_read_only(self): def test_set_read_only(self):
self.assert_(self.conn.readonly is None) self.assert_(self.conn.readonly is None)
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.conn.set_session(readonly=True) self.conn.set_session(readonly=True)
self.assert_(self.conn.readonly is True) self.assert_(self.conn.readonly is True)
cur.execute("SHOW transaction_read_only;") cur.execute("SHOW transaction_read_only;")
@ -1339,7 +1335,7 @@ class TransactionControlTests(ConnectingTestCase):
self.conn.rollback() self.conn.rollback()
def test_setattr_read_only(self): def test_setattr_read_only(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.conn.readonly = True self.conn.readonly = True
self.assert_(self.conn.readonly is True) self.assert_(self.conn.readonly is True)
cur.execute("SHOW transaction_read_only;") cur.execute("SHOW transaction_read_only;")
@ -1352,7 +1348,7 @@ class TransactionControlTests(ConnectingTestCase):
self.assertEqual(cur.fetchone()[0], 'on') self.assertEqual(cur.fetchone()[0], 'on')
self.conn.rollback() self.conn.rollback()
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.conn.readonly = None self.conn.readonly = None
self.assert_(self.conn.readonly is None) self.assert_(self.conn.readonly is None)
cur.execute("SHOW transaction_read_only;") cur.execute("SHOW transaction_read_only;")
@ -1366,7 +1362,7 @@ class TransactionControlTests(ConnectingTestCase):
self.conn.rollback() self.conn.rollback()
def test_set_default(self): def test_set_default(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("SHOW transaction_isolation;") cur.execute("SHOW transaction_isolation;")
isolevel = cur.fetchone()[0] isolevel = cur.fetchone()[0]
cur.execute("SHOW transaction_read_only;") cur.execute("SHOW transaction_read_only;")
@ -1384,7 +1380,7 @@ class TransactionControlTests(ConnectingTestCase):
@skip_before_postgres(9, 1) @skip_before_postgres(9, 1)
def test_set_deferrable(self): def test_set_deferrable(self):
self.assert_(self.conn.deferrable is None) self.assert_(self.conn.deferrable is None)
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.conn.set_session(readonly=True, deferrable=True) self.conn.set_session(readonly=True, deferrable=True)
self.assert_(self.conn.deferrable is True) self.assert_(self.conn.deferrable is True)
cur.execute("SHOW transaction_read_only;") cur.execute("SHOW transaction_read_only;")
@ -1413,7 +1409,7 @@ class TransactionControlTests(ConnectingTestCase):
@skip_before_postgres(9, 1) @skip_before_postgres(9, 1)
def test_setattr_deferrable(self): def test_setattr_deferrable(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.conn.deferrable = True self.conn.deferrable = True
self.assert_(self.conn.deferrable is True) self.assert_(self.conn.deferrable is True)
cur.execute("SHOW transaction_deferrable;") cur.execute("SHOW transaction_deferrable;")
@ -1426,7 +1422,7 @@ class TransactionControlTests(ConnectingTestCase):
self.assertEqual(cur.fetchone()[0], 'on') self.assertEqual(cur.fetchone()[0], 'on')
self.conn.rollback() self.conn.rollback()
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.conn.deferrable = None self.conn.deferrable = None
self.assert_(self.conn.deferrable is None) self.assert_(self.conn.deferrable is None)
cur.execute("SHOW transaction_deferrable;") cur.execute("SHOW transaction_deferrable;")
@ -1440,7 +1436,7 @@ class TransactionControlTests(ConnectingTestCase):
self.conn.rollback() self.conn.rollback()
def test_mixing_session_attribs(self): def test_mixing_session_attribs(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.conn.autocommit = True self.conn.autocommit = True
self.conn.readonly = True self.conn.readonly = True
@ -1463,7 +1459,7 @@ class TransactionControlTests(ConnectingTestCase):
self.conn.autocommit = True self.conn.autocommit = True
self.conn.readonly = True self.conn.readonly = True
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("SHOW transaction_read_only") cur.execute("SHOW transaction_read_only")
self.assertEqual(cur.fetchone()[0], 'on') self.assertEqual(cur.fetchone()[0], 'on')
@ -1486,7 +1482,7 @@ class TestEncryptPassword(ConnectingTestCase):
@skip_before_libpq(10) @skip_before_libpq(10)
@skip_before_postgres(10) @skip_before_postgres(10)
def test_encrypt_server(self): def test_encrypt_server(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("SHOW password_encryption;") cur.execute("SHOW password_encryption;")
server_encryption_algorithm = cur.fetchone()[0] server_encryption_algorithm = cur.fetchone()[0]
@ -1568,7 +1564,7 @@ class AutocommitTests(ConnectingTestCase):
self.assertEqual(self.conn.info.transaction_status, self.assertEqual(self.conn.info.transaction_status,
ext.TRANSACTION_STATUS_IDLE) ext.TRANSACTION_STATUS_IDLE)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute('select 1;') cur.execute('select 1;')
self.assertEqual(self.conn.status, ext.STATUS_BEGIN) self.assertEqual(self.conn.status, ext.STATUS_BEGIN)
self.assertEqual(self.conn.info.transaction_status, self.assertEqual(self.conn.info.transaction_status,
@ -1586,7 +1582,7 @@ class AutocommitTests(ConnectingTestCase):
self.assertEqual(self.conn.info.transaction_status, self.assertEqual(self.conn.info.transaction_status,
ext.TRANSACTION_STATUS_IDLE) ext.TRANSACTION_STATUS_IDLE)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute('select 1;') cur.execute('select 1;')
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assertEqual(self.conn.info.transaction_status, self.assertEqual(self.conn.info.transaction_status,
@ -1604,7 +1600,7 @@ class AutocommitTests(ConnectingTestCase):
ext.TRANSACTION_STATUS_INTRANS) ext.TRANSACTION_STATUS_INTRANS)
def test_set_intrans_error(self): def test_set_intrans_error(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute('select 1;') cur.execute('select 1;')
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
setattr, self.conn, 'autocommit', True) setattr, self.conn, 'autocommit', True)
@ -1616,7 +1612,7 @@ class AutocommitTests(ConnectingTestCase):
self.assertEqual(self.conn.info.transaction_status, self.assertEqual(self.conn.info.transaction_status,
ext.TRANSACTION_STATUS_IDLE) ext.TRANSACTION_STATUS_IDLE)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute('select 1;') cur.execute('select 1;')
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assertEqual(self.conn.info.transaction_status, self.assertEqual(self.conn.info.transaction_status,
@ -1762,7 +1758,7 @@ class TestConnectionInfo(ConnectingTestCase):
self.assert_(self.bconn.info.dbname is None) self.assert_(self.bconn.info.dbname is None)
def test_user(self): def test_user(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select user") cur.execute("select user")
self.assertEqual(self.conn.info.user, cur.fetchone()[0]) self.assertEqual(self.conn.info.user, cur.fetchone()[0])
self.assert_(self.bconn.info.user is None) self.assert_(self.bconn.info.user is None)
@ -1801,13 +1797,13 @@ class TestConnectionInfo(ConnectingTestCase):
def test_transaction_status(self): def test_transaction_status(self):
self.assertEqual(self.conn.info.transaction_status, 0) self.assertEqual(self.conn.info.transaction_status, 0)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 1") cur.execute("select 1")
self.assertEqual(self.conn.info.transaction_status, 2) self.assertEqual(self.conn.info.transaction_status, 2)
self.assertEqual(self.bconn.info.transaction_status, 4) self.assertEqual(self.bconn.info.transaction_status, 4)
def test_parameter_status(self): def test_parameter_status(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
try: try:
cur.execute("show server_version") cur.execute("show server_version")
except psycopg2.DatabaseError: except psycopg2.DatabaseError:
@ -1826,7 +1822,7 @@ class TestConnectionInfo(ConnectingTestCase):
self.assertEqual(self.bconn.info.protocol_version, 0) self.assertEqual(self.bconn.info.protocol_version, 0)
def test_server_version(self): def test_server_version(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
try: try:
cur.execute("show server_version_num") cur.execute("show server_version_num")
except psycopg2.DatabaseError: except psycopg2.DatabaseError:
@ -1841,7 +1837,7 @@ class TestConnectionInfo(ConnectingTestCase):
self.assertIsNone(self.conn.info.error_message) self.assertIsNone(self.conn.info.error_message)
self.assertIsNotNone(self.bconn.info.error_message) self.assertIsNotNone(self.bconn.info.error_message)
cur = self.conn.cursor() with self.conn.cursor() as cur:
try: try:
cur.execute("select 1 from nosuchtable") cur.execute("select 1 from nosuchtable")
except psycopg2.DatabaseError: except psycopg2.DatabaseError:
@ -1854,7 +1850,7 @@ class TestConnectionInfo(ConnectingTestCase):
self.assert_(self.bconn.info.socket < 0) self.assert_(self.bconn.info.socket < 0)
def test_backend_pid(self): def test_backend_pid(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
try: try:
cur.execute("select pg_backend_pid()") cur.execute("select pg_backend_pid()")
except psycopg2.DatabaseError: except psycopg2.DatabaseError:

View File

@ -66,7 +66,7 @@ class CopyTests(ConnectingTestCase):
self._create_temp_table() self._create_temp_table()
def _create_temp_table(self): def _create_temp_table(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(''' curs.execute('''
CREATE TEMPORARY TABLE tcopy ( CREATE TEMPORARY TABLE tcopy (
id serial PRIMARY KEY, id serial PRIMARY KEY,
@ -92,7 +92,7 @@ class CopyTests(ConnectingTestCase):
curs.close() curs.close()
def test_copy_from_cols(self): def test_copy_from_cols(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
f = StringIO() f = StringIO()
for i in range(10): for i in range(10):
f.write("%s\n" % (i,)) f.write("%s\n" % (i,))
@ -104,7 +104,7 @@ class CopyTests(ConnectingTestCase):
self.assertEqual([(i, None) for i in range(10)], curs.fetchall()) self.assertEqual([(i, None) for i in range(10)], curs.fetchall())
def test_copy_from_cols_err(self): def test_copy_from_cols_err(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
f = StringIO() f = StringIO()
for i in range(10): for i in range(10):
f.write("%s\n" % (i,)) f.write("%s\n" % (i,))
@ -140,7 +140,7 @@ class CopyTests(ConnectingTestCase):
+ list(range(160, 256))).decode('latin1') + list(range(160, 256))).decode('latin1')
about = abin.replace('\\', '\\\\') about = abin.replace('\\', '\\\\')
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('insert into tcopy values (%s, %s)', curs.execute('insert into tcopy values (%s, %s)',
(42, abin)) (42, abin))
@ -161,7 +161,7 @@ class CopyTests(ConnectingTestCase):
+ list(range(160, 255))).decode('latin1') + list(range(160, 255))).decode('latin1')
about = abin.replace('\\', '\\\\').encode('latin1') about = abin.replace('\\', '\\\\').encode('latin1')
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('insert into tcopy values (%s, %s)', curs.execute('insert into tcopy values (%s, %s)',
(42, abin)) (42, abin))
@ -188,7 +188,7 @@ class CopyTests(ConnectingTestCase):
f.write(about) f.write(about)
f.seek(0) f.seek(0)
curs = self.conn.cursor() with self.conn.cursor() as curs:
psycopg2.extensions.register_type( psycopg2.extensions.register_type(
psycopg2.extensions.UNICODE, curs) psycopg2.extensions.UNICODE, curs)
@ -254,14 +254,14 @@ class CopyTests(ConnectingTestCase):
pass pass
f = Whatever() f = Whatever()
curs = self.conn.cursor() with self.conn.cursor() as curs:
self.assertRaises(TypeError, self.assertRaises(TypeError,
curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f) curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f)
def test_copy_no_column_limit(self): def test_copy_no_column_limit(self):
cols = ["c%050d" % i for i in range(200)] cols = ["c%050d" % i for i in range(200)]
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join( curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join(
["%s int" % c for c in cols])) ["%s int" % c for c in cols]))
curs.execute("INSERT INTO manycols DEFAULT VALUES") curs.execute("INSERT INTO manycols DEFAULT VALUES")
@ -278,8 +278,7 @@ class CopyTests(ConnectingTestCase):
@skip_before_postgres(8, 2) # they don't send the count @skip_before_postgres(8, 2) # they don't send the count
def test_copy_rowcount(self): def test_copy_rowcount(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data']) curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data'])
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
@ -296,8 +295,7 @@ class CopyTests(ConnectingTestCase):
self.assertEqual(curs.rowcount, 6) self.assertEqual(curs.rowcount, 6)
def test_copy_rowcount_error(self): def test_copy_rowcount_error(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("insert into tcopy (data) values ('fff')") curs.execute("insert into tcopy (data) values ('fff')")
self.assertEqual(curs.rowcount, 1) self.assertEqual(curs.rowcount, 1)
@ -317,6 +315,7 @@ try:
curs.execute("copy copy_segf from stdin") curs.execute("copy copy_segf from stdin")
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
pass pass
curs.close()
conn.close() conn.close()
""" % {'dsn': dsn}) """ % {'dsn': dsn})
@ -336,6 +335,7 @@ try:
curs.execute("copy copy_segf to stdout") curs.execute("copy copy_segf to stdout")
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
pass pass
curs.close()
conn.close() conn.close()
""" % {'dsn': dsn}) """ % {'dsn': dsn})
@ -351,7 +351,7 @@ conn.close()
def readline(self): def readline(self):
return 1 / 0 return 1 / 0
curs = self.conn.cursor() with self.conn.cursor() as curs:
# It seems we cannot do this, but now at least we propagate the error # It seems we cannot do this, but now at least we propagate the error
# self.assertRaises(ZeroDivisionError, # self.assertRaises(ZeroDivisionError,
# curs.copy_from, BrokenRead(), "tcopy") # curs.copy_from, BrokenRead(), "tcopy")
@ -365,7 +365,7 @@ conn.close()
def write(self, data): def write(self, data):
return 1 / 0 return 1 / 0
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("insert into tcopy values (10, 'hi')") curs.execute("insert into tcopy values (10, 'hi')")
self.assertRaises(ZeroDivisionError, self.assertRaises(ZeroDivisionError,
curs.copy_to, BrokenWrite(), "tcopy") curs.copy_to, BrokenWrite(), "tcopy")

View File

@ -51,7 +51,7 @@ class CursorTests(ConnectingTestCase):
self.assert_(cur.closed) self.assert_(cur.closed)
def test_empty_query(self): def test_empty_query(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises(psycopg2.ProgrammingError, cur.execute, "") self.assertRaises(psycopg2.ProgrammingError, cur.execute, "")
self.assertRaises(psycopg2.ProgrammingError, cur.execute, " ") self.assertRaises(psycopg2.ProgrammingError, cur.execute, " ")
self.assertRaises(psycopg2.ProgrammingError, cur.execute, ";") self.assertRaises(psycopg2.ProgrammingError, cur.execute, ";")
@ -70,8 +70,7 @@ class CursorTests(ConnectingTestCase):
def test_mogrify_unicode(self): def test_mogrify_unicode(self):
conn = self.conn conn = self.conn
cur = conn.cursor() with conn.cursor() as cur:
# test consistency between execute and mogrify. # test consistency between execute and mogrify.
# unicode query containing only ascii data # unicode query containing only ascii data
@ -108,7 +107,7 @@ class CursorTests(ConnectingTestCase):
def test_mogrify_decimal_explodes(self): def test_mogrify_decimal_explodes(self):
conn = self.conn conn = self.conn
cur = conn.cursor() with conn.cursor() as cur:
self.assertEqual(b'SELECT 10.3;', self.assertEqual(b'SELECT 10.3;',
cur.mogrify("SELECT %s;", (Decimal("10.3"),))) cur.mogrify("SELECT %s;", (Decimal("10.3"),)))
@ -116,7 +115,7 @@ class CursorTests(ConnectingTestCase):
def test_mogrify_leak_on_multiple_reference(self): def test_mogrify_leak_on_multiple_reference(self):
# issue #81: reference leak when a parameter value is referenced # issue #81: reference leak when a parameter value is referenced
# more than once from a dict. # more than once from a dict.
cur = self.conn.cursor() with self.conn.cursor() as cur:
foo = (lambda x: x)('foo') * 10 foo = (lambda x: x)('foo') * 10
nref1 = sys.getrefcount(foo) nref1 = sys.getrefcount(foo)
cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo}) cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo})
@ -130,7 +129,7 @@ class CursorTests(ConnectingTestCase):
self.assertEqual(sql, b"select 10") self.assertEqual(sql, b"select 10")
def test_bad_placeholder(self): def test_bad_placeholder(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo", {}) cur.mogrify, "select %(foo", {})
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
@ -141,8 +140,7 @@ class CursorTests(ConnectingTestCase):
cur.mogrify, "select %(foo, %(bar)", {'foo': 1, 'bar': 2}) cur.mogrify, "select %(foo, %(bar)", {'foo': 1, 'bar': 2})
def test_cast(self): def test_cast(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
self.assertEqual(42, curs.cast(20, '42')) self.assertEqual(42, curs.cast(20, '42'))
self.assertAlmostEqual(3.14, curs.cast(700, '3.14')) self.assertAlmostEqual(3.14, curs.cast(700, '3.14'))
@ -152,7 +150,7 @@ class CursorTests(ConnectingTestCase):
self.assertEqual("who am i?", curs.cast(705, 'who am i?')) # unknown self.assertEqual("who am i?", curs.cast(705, 'who am i?')) # unknown
def test_cast_specificity(self): def test_cast_specificity(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
self.assertEqual("foo", curs.cast(705, 'foo')) self.assertEqual("foo", curs.cast(705, 'foo'))
D = psycopg2.extensions.new_type((705,), "DOUBLING", lambda v, c: v * 2) D = psycopg2.extensions.new_type((705,), "DOUBLING", lambda v, c: v * 2)
@ -163,18 +161,18 @@ class CursorTests(ConnectingTestCase):
psycopg2.extensions.register_type(T, curs) psycopg2.extensions.register_type(T, curs)
self.assertEqual("foofoofoo", curs.cast(705, 'foo')) self.assertEqual("foofoofoo", curs.cast(705, 'foo'))
curs2 = self.conn.cursor() with self.conn.cursor() as curs2:
self.assertEqual("foofoo", curs2.cast(705, 'foo')) self.assertEqual("foofoo", curs2.cast(705, 'foo'))
def test_weakref(self): def test_weakref(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
w = ref(curs) w = ref(curs)
del curs del curs
gc.collect() gc.collect()
self.assert_(w() is None) self.assert_(w() is None)
def test_null_name(self): def test_null_name(self):
curs = self.conn.cursor(None) with self.conn.cursor(None) as curs:
self.assertEqual(curs.name, None) self.assertEqual(curs.name, None)
def test_invalid_name(self): def test_invalid_name(self):
@ -184,7 +182,7 @@ class CursorTests(ConnectingTestCase):
curs.execute("insert into invname values (%s)", (i,)) curs.execute("insert into invname values (%s)", (i,))
curs.close() curs.close()
curs = self.conn.cursor(r'1-2-3 \ "test"') with self.conn.cursor(r'1-2-3 \ "test"') as curs:
curs.execute("select data from invname order by data") curs.execute("select data from invname order by data")
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
@ -213,13 +211,13 @@ class CursorTests(ConnectingTestCase):
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
curs.close() curs.close()
curs = self.conn.cursor("W", withhold=True) with self.conn.cursor("W", withhold=True) as curs:
self.assertEqual(curs.withhold, True) self.assertEqual(curs.withhold, True)
curs.execute("select data from withhold order by data") curs.execute("select data from withhold order by data")
self.conn.commit() self.conn.commit()
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("drop table withhold") curs.execute("drop table withhold")
self.conn.commit() self.conn.commit()
@ -328,7 +326,7 @@ class CursorTests(ConnectingTestCase):
return self.skipTest("can't evaluate non-scrollable cursor") return self.skipTest("can't evaluate non-scrollable cursor")
curs.close() curs.close()
curs = self.conn.cursor("S", scrollable=False) with self.conn.cursor("S", scrollable=False) as curs:
self.assertEqual(curs.scrollable, False) self.assertEqual(curs.scrollable, False)
curs.execute("select * from scrollable") curs.execute("select * from scrollable")
curs.scroll(2) curs.scroll(2)
@ -337,7 +335,7 @@ class CursorTests(ConnectingTestCase):
@slow @slow
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_iter_named_cursor_efficient(self): def test_iter_named_cursor_efficient(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
# if these records are fetched in the same roundtrip their # if these records are fetched in the same roundtrip their
# timestamp will not be influenced by the pause in Python world. # timestamp will not be influenced by the pause in Python world.
curs.execute("""select clock_timestamp() from generate_series(1,2)""") curs.execute("""select clock_timestamp() from generate_series(1,2)""")
@ -351,7 +349,7 @@ class CursorTests(ConnectingTestCase):
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_iter_named_cursor_default_itersize(self): def test_iter_named_cursor_default_itersize(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.execute('select generate_series(1,50)') curs.execute('select generate_series(1,50)')
rv = [(r[0], curs.rownumber) for r in curs] rv = [(r[0], curs.rownumber) for r in curs]
# everything swallowed in one gulp # everything swallowed in one gulp
@ -359,7 +357,7 @@ class CursorTests(ConnectingTestCase):
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_iter_named_cursor_itersize(self): def test_iter_named_cursor_itersize(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.itersize = 30 curs.itersize = 30
curs.execute('select generate_series(1,50)') curs.execute('select generate_series(1,50)')
rv = [(r[0], curs.rownumber) for r in curs] rv = [(r[0], curs.rownumber) for r in curs]
@ -368,7 +366,7 @@ class CursorTests(ConnectingTestCase):
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_iter_named_cursor_rownumber(self): def test_iter_named_cursor_rownumber(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
# note: this fails if itersize < dataset: internally we check # note: this fails if itersize < dataset: internally we check
# rownumber == rowcount to detect when to read anoter page, so we # rownumber == rowcount to detect when to read anoter page, so we
# would need an extra attribute to have a monotonic rownumber. # would need an extra attribute to have a monotonic rownumber.
@ -378,7 +376,7 @@ class CursorTests(ConnectingTestCase):
self.assertEqual(i + 1, curs.rownumber) self.assertEqual(i + 1, curs.rownumber)
def test_description_attribs(self): def test_description_attribs(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select curs.execute("""select
3.14::decimal(10,2) as pi, 3.14::decimal(10,2) as pi,
'hello'::text as hi, 'hello'::text as hi,
@ -413,7 +411,7 @@ class CursorTests(ConnectingTestCase):
self.assertEqual(c.scale, None) self.assertEqual(c.scale, None)
def test_description_extra_attribs(self): def test_description_extra_attribs(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(""" curs.execute("""
create table testcol ( create table testcol (
pi decimal(10,2), pi decimal(10,2),
@ -434,7 +432,7 @@ class CursorTests(ConnectingTestCase):
self.assertEqual(curs.description[2].table_column, None) self.assertEqual(curs.description[2].table_column, None)
def test_pickle_description(self): def test_pickle_description(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('SELECT 1 AS foo') curs.execute('SELECT 1 AS foo')
description = curs.description description = curs.description
@ -447,11 +445,11 @@ class CursorTests(ConnectingTestCase):
def test_named_cursor_stealing(self): def test_named_cursor_stealing(self):
# you can use a named cursor to iterate on a refcursor created # you can use a named cursor to iterate on a refcursor created
# somewhere else # somewhere else
cur1 = self.conn.cursor() with self.conn.cursor() as cur1:
cur1.execute("DECLARE test CURSOR WITHOUT HOLD " cur1.execute("DECLARE test CURSOR WITHOUT HOLD "
" FOR SELECT generate_series(1,7)") " FOR SELECT generate_series(1,7)")
cur2 = self.conn.cursor('test') with self.conn.cursor('test') as cur2:
# can call fetch without execute # can call fetch without execute
self.assertEqual((1,), cur2.fetchone()) self.assertEqual((1,), cur2.fetchone())
self.assertEqual([(2,), (3,), (4,)], cur2.fetchmany(3)) self.assertEqual([(2,), (3,), (4,)], cur2.fetchmany(3))
@ -464,7 +462,7 @@ class CursorTests(ConnectingTestCase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_stolen_named_cursor_close(self): def test_stolen_named_cursor_close(self):
cur1 = self.conn.cursor() with self.conn.cursor() as cur1:
cur1.execute("DECLARE test CURSOR WITHOUT HOLD " cur1.execute("DECLARE test CURSOR WITHOUT HOLD "
" FOR SELECT generate_series(1,7)") " FOR SELECT generate_series(1,7)")
cur2 = self.conn.cursor('test') cur2 = self.conn.cursor('test')
@ -477,7 +475,7 @@ class CursorTests(ConnectingTestCase):
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_scroll(self): def test_scroll(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select generate_series(0,9)") cur.execute("select generate_series(0,9)")
cur.scroll(2) cur.scroll(2)
self.assertEqual(cur.fetchone(), (2,)) self.assertEqual(cur.fetchone(), (2,))
@ -511,7 +509,7 @@ class CursorTests(ConnectingTestCase):
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_scroll_named(self): def test_scroll_named(self):
cur = self.conn.cursor('tmp', scrollable=True) with self.conn.cursor('tmp', scrollable=True) as cur:
cur.execute("select generate_series(0,9)") cur.execute("select generate_series(0,9)")
cur.scroll(2) cur.scroll(2)
self.assertEqual(cur.fetchone(), (2,)) self.assertEqual(cur.fetchone(), (2,))
@ -531,13 +529,13 @@ class CursorTests(ConnectingTestCase):
# I am stupid so not calling superclass init # I am stupid so not calling superclass init
pass pass
cur = StupidCursor() with StupidCursor() as cur:
self.assertRaises(psycopg2.InterfaceError, cur.execute, 'select 1') self.assertRaises(psycopg2.InterfaceError, cur.execute, 'select 1')
self.assertRaises(psycopg2.InterfaceError, cur.executemany, self.assertRaises(psycopg2.InterfaceError, cur.executemany,
'select 1', []) 'select 1', [])
def test_callproc_badparam(self): def test_callproc_badparam(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises(TypeError, cur.callproc, 'lower', 42) self.assertRaises(TypeError, cur.callproc, 'lower', 42)
# It would be inappropriate to test callproc's named parameters in the # It would be inappropriate to test callproc's named parameters in the
@ -551,8 +549,7 @@ class CursorTests(ConnectingTestCase):
escaped_paramname = '"%s"' % paramname.replace('"', '""') escaped_paramname = '"%s"' % paramname.replace('"', '""')
procname = 'pg_temp.randall' procname = 'pg_temp.randall'
cur = self.conn.cursor() with self.conn.cursor() as cur:
# Set up the temporary function # Set up the temporary function
cur.execute(''' cur.execute('''
CREATE FUNCTION %s(%s INT) CREATE FUNCTION %s(%s INT)
@ -638,7 +635,7 @@ class CursorTests(ConnectingTestCase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_rowcount_on_executemany_returning(self): def test_rowcount_on_executemany_returning(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("create table execmany(id serial primary key, data int)") cur.execute("create table execmany(id serial primary key, data int)")
cur.executemany( cur.executemany(
"insert into execmany (data) values (%s)", "insert into execmany (data) values (%s)",

View File

@ -369,7 +369,7 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
self.assertEqual(total_seconds(t), 1e-6) self.assertEqual(total_seconds(t), 1e-6)
def test_interval_overflow(self): def test_interval_overflow(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
# hack a cursor to receive values too extreme to be represented # hack a cursor to receive values too extreme to be represented
# but still I want an error, not a random number # but still I want an error, not a random number
psycopg2.extensions.register_type( psycopg2.extensions.register_type(
@ -405,7 +405,7 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
def test_redshift_day(self): def test_redshift_day(self):
# Redshift is reported returning 1 day interval as microsec (bug #558) # Redshift is reported returning 1 day interval as microsec (bug #558)
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extensions.register_type( psycopg2.extensions.register_type(
psycopg2.extensions.new_type( psycopg2.extensions.new_type(
psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL), psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL),
@ -426,7 +426,7 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
@skip_before_postgres(8, 4) @skip_before_postgres(8, 4)
def test_interval_iso_8601_not_supported(self): def test_interval_iso_8601_not_supported(self):
# We may end up supporting, but no pressure for it # We may end up supporting, but no pressure for it
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("set local intervalstyle to iso_8601") cur.execute("set local intervalstyle to iso_8601")
cur.execute("select '1 day 2 hours'::interval") cur.execute("select '1 day 2 hours'::interval")
self.assertRaises(psycopg2.NotSupportedError, cur.fetchone) self.assertRaises(psycopg2.NotSupportedError, cur.fetchone)

View File

@ -32,7 +32,7 @@ from .testutils import ConnectingTestCase, skip_before_postgres, \
class _DictCursorBase(ConnectingTestCase): class _DictCursorBase(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("CREATE TEMPORARY TABLE ExtrasDictCursorTests (foo text)") curs.execute("CREATE TEMPORARY TABLE ExtrasDictCursorTests (foo text)")
curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')") curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')")
self.conn.commit() self.conn.commit()
@ -61,15 +61,18 @@ class _DictCursorBase(ConnectingTestCase):
class ExtrasDictCursorTests(_DictCursorBase): class ExtrasDictCursorTests(_DictCursorBase):
"""Test if DictCursor extension class works.""" """Test if DictCursor extension class works."""
@skip_before_postgres(8, 2)
def testDictConnCursorArgs(self): def testDictConnCursorArgs(self):
self.conn.close() self.conn.close()
self.conn = self.connect(connection_factory=psycopg2.extras.DictConnection) self.conn = self.connect(connection_factory=psycopg2.extras.DictConnection)
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) self.assert_(isinstance(cur, psycopg2.extras.DictCursor))
self.assertEqual(cur.name, None) self.assertEqual(cur.name, None)
# overridable # overridable
cur = self.conn.cursor('foo', with self.conn.cursor(
cursor_factory=psycopg2.extras.NamedTupleCursor) 'foo',
cursor_factory=psycopg2.extras.NamedTupleCursor
) as cur:
self.assertEqual(cur.name, 'foo') self.assertEqual(cur.name, 'foo')
self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor)) self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor))
@ -99,11 +102,11 @@ class ExtrasDictCursorTests(_DictCursorBase):
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def testDictCursorWithPlainCursorIterRowNumber(self): def testDictCursorWithPlainCursorIterRowNumber(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
self._testIterRowNumber(curs) self._testIterRowNumber(curs)
def _testWithPlainCursor(self, getter): def _testWithPlainCursor(self, getter):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
curs.execute("SELECT * FROM ExtrasDictCursorTests") curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs) row = getter(curs)
self.failUnless(row['foo'] == 'bar') self.failUnless(row['foo'] == 'bar')
@ -130,23 +133,32 @@ class ExtrasDictCursorTests(_DictCursorBase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def testDictCursorWithNamedCursorNotGreedy(self): def testDictCursorWithNamedCursorNotGreedy(self):
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(
'tmp',
cursor_factory=psycopg2.extras.DictCursor
) as curs:
self._testNamedCursorNotGreedy(curs) self._testNamedCursorNotGreedy(curs)
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def testDictCursorWithNamedCursorIterRowNumber(self): def testDictCursorWithNamedCursorIterRowNumber(self):
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(
'tmp',
cursor_factory=psycopg2.extras.DictCursor
) as curs:
self._testIterRowNumber(curs) self._testIterRowNumber(curs)
def _testWithNamedCursor(self, getter): def _testWithNamedCursor(self, getter):
curs = self.conn.cursor('aname', cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(
'aname',
cursor_factory=psycopg2.extras.DictCursor
) as curs:
curs.execute("SELECT * FROM ExtrasDictCursorTests") curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs) row = getter(curs)
self.failUnless(row['foo'] == 'bar') self.failUnless(row['foo'] == 'bar')
self.failUnless(row[0] == 'bar') self.failUnless(row[0] == 'bar')
def testPickleDictRow(self): def testPickleDictRow(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r = curs.fetchone() r = curs.fetchone()
d = pickle.dumps(r) d = pickle.dumps(r)
@ -160,7 +172,7 @@ class ExtrasDictCursorTests(_DictCursorBase):
@skip_from_python(3) @skip_from_python(3)
def test_iter_methods_2(self): def test_iter_methods_2(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r = curs.fetchone() r = curs.fetchone()
self.assert_(isinstance(r.keys(), list)) self.assert_(isinstance(r.keys(), list))
@ -179,7 +191,7 @@ class ExtrasDictCursorTests(_DictCursorBase):
@skip_before_python(3) @skip_before_python(3)
def test_iter_methods_3(self): def test_iter_methods_3(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r = curs.fetchone() r = curs.fetchone()
self.assert_(not isinstance(r.keys(), list)) self.assert_(not isinstance(r.keys(), list))
@ -190,7 +202,7 @@ class ExtrasDictCursorTests(_DictCursorBase):
self.assertEqual(len(list(r.items())), 2) self.assertEqual(len(list(r.items())), 2)
def test_order(self): def test_order(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux")
r = curs.fetchone() r = curs.fetchone()
self.assertEqual(list(r), [5, 4, 33, 2]) self.assertEqual(list(r), [5, 4, 33, 2])
@ -223,7 +235,7 @@ class ExtrasDictCursorTests(_DictCursorBase):
class ExtrasDictCursorRealTests(_DictCursorBase): class ExtrasDictCursorRealTests(_DictCursorBase):
def testRealMeansReal(self): def testRealMeansReal(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("SELECT * FROM ExtrasDictCursorTests") curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = curs.fetchone() row = curs.fetchone()
self.assert_(isinstance(row, dict)) self.assert_(isinstance(row, dict))
@ -248,17 +260,17 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def testDictCursorWithPlainCursorRealIterRowNumber(self): def testDictCursorWithPlainCursorRealIterRowNumber(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
self._testIterRowNumber(curs) self._testIterRowNumber(curs)
def _testWithPlainCursorReal(self, getter): def _testWithPlainCursorReal(self, getter):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("SELECT * FROM ExtrasDictCursorTests") curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs) row = getter(curs)
self.failUnless(row['foo'] == 'bar') self.failUnless(row['foo'] == 'bar')
def testPickleRealDictRow(self): def testPickleRealDictRow(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r = curs.fetchone() r = curs.fetchone()
d = pickle.dumps(r) d = pickle.dumps(r)
@ -287,24 +299,32 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def testDictCursorRealWithNamedCursorNotGreedy(self): def testDictCursorRealWithNamedCursorNotGreedy(self):
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(
'tmp',
cursor_factory=psycopg2.extras.RealDictCursor
) as curs:
self._testNamedCursorNotGreedy(curs) self._testNamedCursorNotGreedy(curs)
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def testDictCursorRealWithNamedCursorIterRowNumber(self): def testDictCursorRealWithNamedCursorIterRowNumber(self):
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(
'tmp',
cursor_factory=psycopg2.extras.RealDictCursor
) as curs:
self._testIterRowNumber(curs) self._testIterRowNumber(curs)
def _testWithNamedCursorReal(self, getter): def _testWithNamedCursorReal(self, getter):
curs = self.conn.cursor('aname', with self.conn.cursor(
cursor_factory=psycopg2.extras.RealDictCursor) 'aname',
cursor_factory=psycopg2.extras.RealDictCursor
) as curs:
curs.execute("SELECT * FROM ExtrasDictCursorTests") curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs) row = getter(curs)
self.failUnless(row['foo'] == 'bar') self.failUnless(row['foo'] == 'bar')
@skip_from_python(3) @skip_from_python(3)
def test_iter_methods_2(self): def test_iter_methods_2(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r = curs.fetchone() r = curs.fetchone()
self.assert_(isinstance(r.keys(), list)) self.assert_(isinstance(r.keys(), list))
@ -323,7 +343,7 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
@skip_before_python(3) @skip_before_python(3)
def test_iter_methods_3(self): def test_iter_methods_3(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r = curs.fetchone() r = curs.fetchone()
self.assert_(not isinstance(r.keys(), list)) self.assert_(not isinstance(r.keys(), list))
@ -334,7 +354,7 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
self.assertEqual(len(list(r.items())), 2) self.assertEqual(len(list(r.items())), 2)
def test_order(self): def test_order(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux")
r = curs.fetchone() r = curs.fetchone()
self.assertEqual(list(r), ['foo', 'bar', 'baz', 'qux']) self.assertEqual(list(r), ['foo', 'bar', 'baz', 'qux'])
@ -351,7 +371,7 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
@skip_from_python(3) @skip_from_python(3)
def test_order_iter(self): def test_order_iter(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux")
r = curs.fetchone() r = curs.fetchone()
self.assertEqual(list(r.iterkeys()), ['foo', 'bar', 'baz', 'qux']) self.assertEqual(list(r.iterkeys()), ['foo', 'bar', 'baz', 'qux'])
@ -365,7 +385,7 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
self.assertEqual(list(r1.iteritems()), list(r.iteritems())) self.assertEqual(list(r1.iteritems()), list(r.iteritems()))
def test_pop(self): def test_pop(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 1 as a, 2 as b, 3 as c") curs.execute("select 1 as a, 2 as b, 3 as c")
r = curs.fetchone() r = curs.fetchone()
self.assertEqual(r.pop('b'), 2) self.assertEqual(r.pop('b'), 2)
@ -378,7 +398,7 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
self.assertRaises(KeyError, r.pop, 'b') self.assertRaises(KeyError, r.pop, 'b')
def test_mod(self): def test_mod(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 1 as a, 2 as b, 3 as c") curs.execute("select 1 as a, 2 as b, 3 as c")
r = curs.fetchone() r = curs.fetchone()
r['d'] = 4 r['d'] = 4
@ -399,20 +419,24 @@ class NamedTupleCursorTest(ConnectingTestCase):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
self.conn = self.connect(connection_factory=NamedTupleConnection) self.conn = self.connect(connection_factory=NamedTupleConnection)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)") curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)")
curs.execute("INSERT INTO nttest VALUES (1, 'foo')") curs.execute("INSERT INTO nttest VALUES (1, 'foo')")
curs.execute("INSERT INTO nttest VALUES (2, 'bar')") curs.execute("INSERT INTO nttest VALUES (2, 'bar')")
curs.execute("INSERT INTO nttest VALUES (3, 'baz')") curs.execute("INSERT INTO nttest VALUES (3, 'baz')")
self.conn.commit() self.conn.commit()
@skip_before_postgres(8, 2)
def test_cursor_args(self): def test_cursor_args(self):
cur = self.conn.cursor('foo', cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(
'foo',
cursor_factory=psycopg2.extras.DictCursor
) as cur:
self.assertEqual(cur.name, 'foo') self.assertEqual(cur.name, 'foo')
self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) self.assert_(isinstance(cur, psycopg2.extras.DictCursor))
def test_fetchone(self): def test_fetchone(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
t = curs.fetchone() t = curs.fetchone()
self.assertEqual(t[0], 1) self.assertEqual(t[0], 1)
@ -423,7 +447,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
def test_fetchmany_noarg(self): def test_fetchmany_noarg(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.arraysize = 2 curs.arraysize = 2
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
res = curs.fetchmany() res = curs.fetchmany()
@ -436,7 +460,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
def test_fetchmany(self): def test_fetchmany(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
res = curs.fetchmany(2) res = curs.fetchmany(2)
self.assertEqual(2, len(res)) self.assertEqual(2, len(res))
@ -448,7 +472,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
def test_fetchall(self): def test_fetchall(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
res = curs.fetchall() res = curs.fetchall()
self.assertEqual(3, len(res)) self.assertEqual(3, len(res))
@ -462,7 +486,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
def test_executemany(self): def test_executemany(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.executemany("delete from nttest where i = %s", curs.executemany("delete from nttest where i = %s",
[(1,), (2,)]) [(1,), (2,)])
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
@ -472,7 +496,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
self.assertEqual(res[0].s, 'baz') self.assertEqual(res[0].s, 'baz')
def test_iter(self): def test_iter(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
i = iter(curs) i = iter(curs)
self.assertEqual(curs.rownumber, 0) self.assertEqual(curs.rownumber, 0)
@ -497,7 +521,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
def test_record_updated(self): def test_record_updated(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select 1 as foo;") curs.execute("select 1 as foo;")
r = curs.fetchone() r = curs.fetchone()
self.assertEqual(r.foo, 1) self.assertEqual(r.foo, 1)
@ -508,7 +532,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
self.assertRaises(AttributeError, getattr, r, 'foo') self.assertRaises(AttributeError, getattr, r, 'foo')
def test_no_result_no_surprise(self): def test_no_result_no_surprise(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("update nttest set s = s") curs.execute("update nttest set s = s")
self.assertRaises(psycopg2.ProgrammingError, curs.fetchone) self.assertRaises(psycopg2.ProgrammingError, curs.fetchone)
@ -516,7 +540,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
self.assertRaises(psycopg2.ProgrammingError, curs.fetchall) self.assertRaises(psycopg2.ProgrammingError, curs.fetchall)
def test_bad_col_names(self): def test_bad_col_names(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('select 1 as "foo.bar_baz", 2 as "?column?", 3 as "3"') curs.execute('select 1 as "foo.bar_baz", 2 as "?column?", 3 as "3"')
rv = curs.fetchone() rv = curs.fetchone()
self.assertEqual(rv.foo_bar_baz, 1) self.assertEqual(rv.foo_bar_baz, 1)
@ -526,7 +550,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
@skip_before_python(3) @skip_before_python(3)
@skip_before_postgres(8) @skip_before_postgres(8)
def test_nonascii_name(self): def test_nonascii_name(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('select 1 as \xe5h\xe9') curs.execute('select 1 as \xe5h\xe9')
rv = curs.fetchone() rv = curs.fetchone()
self.assertEqual(getattr(rv, '\xe5h\xe9'), 1) self.assertEqual(getattr(rv, '\xe5h\xe9'), 1)
@ -543,7 +567,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
NamedTupleCursor._make_nt = f_patched NamedTupleCursor._make_nt = f_patched
try: try:
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
curs.fetchone() curs.fetchone()
curs.fetchone() curs.fetchone()
@ -565,7 +589,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_named(self): def test_named(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.execute("""select i from generate_series(0,9) i""") curs.execute("""select i from generate_series(0,9) i""")
recs = [] recs = []
recs.extend(curs.fetchmany(5)) recs.extend(curs.fetchmany(5))
@ -574,28 +598,29 @@ class NamedTupleCursorTest(ConnectingTestCase):
self.assertEqual(list(range(10)), [t.i for t in recs]) self.assertEqual(list(range(10)), [t.i for t in recs])
def test_named_fetchone(self): def test_named_fetchone(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.execute("""select 42 as i""") curs.execute("""select 42 as i""")
t = curs.fetchone() t = curs.fetchone()
self.assertEqual(t.i, 42) self.assertEqual(t.i, 42)
def test_named_fetchmany(self): def test_named_fetchmany(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.execute("""select 42 as i""") curs.execute("""select 42 as i""")
recs = curs.fetchmany(10) recs = curs.fetchmany(10)
self.assertEqual(recs[0].i, 42) self.assertEqual(recs[0].i, 42)
def test_named_fetchall(self): def test_named_fetchall(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.execute("""select 42 as i""") curs.execute("""select 42 as i""")
recs = curs.fetchall() recs = curs.fetchall()
self.assertEqual(recs[0].i, 42) self.assertEqual(recs[0].i, 42)
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_not_greedy(self): def test_not_greedy(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.itersize = 2 curs.itersize = 2
curs.execute("""select clock_timestamp() as ts from generate_series(1,3)""") curs.execute(
"""select clock_timestamp() as ts from generate_series(1,3)""")
recs = [] recs = []
for t in curs: for t in curs:
time.sleep(0.01) time.sleep(0.01)
@ -607,7 +632,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_named_rownumber(self): def test_named_rownumber(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
# Only checking for dataset < itersize: # Only checking for dataset < itersize:
# see CursorTests.test_iter_named_cursor_rownumber # see CursorTests.test_iter_named_cursor_rownumber
curs.itersize = 4 curs.itersize = 4
@ -618,14 +643,14 @@ class NamedTupleCursorTest(ConnectingTestCase):
def test_cache(self): def test_cache(self):
NamedTupleCursor._cached_make_nt.cache_clear() NamedTupleCursor._cached_make_nt.cache_clear()
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r1 = curs.fetchone() r1 = curs.fetchone()
curs.execute("select 10 as a, 20 as c") curs.execute("select 10 as a, 20 as c")
r2 = curs.fetchone() r2 = curs.fetchone()
# Get a new cursor to check that the cache works across multiple ones # Get a new cursor to check that the cache works across multiple ones
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select 10 as a, 30 as b") curs.execute("select 10 as a, 30 as b")
r3 = curs.fetchone() r3 = curs.fetchone()
@ -643,7 +668,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
lru_cache(8)(NamedTupleCursor._cached_make_nt.__wrapped__) lru_cache(8)(NamedTupleCursor._cached_make_nt.__wrapped__)
try: try:
recs = [] recs = []
curs = self.conn.cursor() with self.conn.cursor() as curs:
for i in range(10): for i in range(10):
curs.execute("select 1 as f%s" % i) curs.execute("select 1 as f%s" % i)
recs.append(curs.fetchone()) recs.append(curs.fetchone())

View File

@ -46,14 +46,14 @@ class TestPaginate(unittest.TestCase):
class FastExecuteTestMixin(object): class FastExecuteTestMixin(object):
def setUp(self): def setUp(self):
super(FastExecuteTestMixin, self).setUp() super(FastExecuteTestMixin, self).setUp()
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("""create table testfast ( cur.execute("""create table testfast (
id serial primary key, date date, val int, data text)""") id serial primary key, date date, val int, data text)""")
class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase): class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
def test_empty(self): def test_empty(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, val) values (%s, %s)", "insert into testfast (id, val) values (%s, %s)",
[]) [])
@ -61,7 +61,7 @@ class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
self.assertEqual(cur.fetchall(), []) self.assertEqual(cur.fetchall(), [])
def test_one(self): def test_one(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, val) values (%s, %s)", "insert into testfast (id, val) values (%s, %s)",
iter([(1, 10)])) iter([(1, 10)]))
@ -69,7 +69,7 @@ class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
self.assertEqual(cur.fetchall(), [(1, 10)]) self.assertEqual(cur.fetchall(), [(1, 10)])
def test_tuples(self): def test_tuples(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, date, val) values (%s, %s, %s)", "insert into testfast (id, date, val) values (%s, %s, %s)",
((i, date(2017, 1, i + 1), i * 10) for i in range(10))) ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
@ -78,7 +78,7 @@ class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
[(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
def test_many(self): def test_many(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, val) values (%s, %s)", "insert into testfast (id, val) values (%s, %s)",
((i, i * 10) for i in range(1000))) ((i, i * 10) for i in range(1000)))
@ -86,7 +86,7 @@ class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
def test_composed(self): def test_composed(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
sql.SQL("insert into {0} (id, val) values (%s, %s)") sql.SQL("insert into {0} (id, val) values (%s, %s)")
.format(sql.Identifier('testfast')), .format(sql.Identifier('testfast')),
@ -95,7 +95,7 @@ class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
def test_pages(self): def test_pages(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, val) values (%s, %s)", "insert into testfast (id, val) values (%s, %s)",
((i, i * 10) for i in range(25)), ((i, i * 10) for i in range(25)),
@ -109,7 +109,7 @@ class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
@testutils.skip_before_postgres(8, 0) @testutils.skip_before_postgres(8, 0)
def test_unicode(self): def test_unicode(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
ext.register_type(ext.UNICODE, cur) ext.register_type(ext.UNICODE, cur)
snowman = u"\u2603" snowman = u"\u2603"
@ -138,7 +138,7 @@ class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
@testutils.skip_before_postgres(8, 2) @testutils.skip_before_postgres(8, 2)
class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase): class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
def test_empty(self): def test_empty(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s", "insert into testfast (id, val) values %s",
[]) [])
@ -146,7 +146,7 @@ class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
self.assertEqual(cur.fetchall(), []) self.assertEqual(cur.fetchall(), [])
def test_one(self): def test_one(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s", "insert into testfast (id, val) values %s",
iter([(1, 10)])) iter([(1, 10)]))
@ -154,7 +154,7 @@ class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
self.assertEqual(cur.fetchall(), [(1, 10)]) self.assertEqual(cur.fetchall(), [(1, 10)])
def test_tuples(self): def test_tuples(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, date, val) values %s", "insert into testfast (id, date, val) values %s",
((i, date(2017, 1, i + 1), i * 10) for i in range(10))) ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
@ -163,7 +163,7 @@ class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
[(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
def test_dicts(self): def test_dicts(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, date, val) values %s", "insert into testfast (id, date, val) values %s",
(dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar") (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar")
@ -174,7 +174,7 @@ class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
[(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
def test_many(self): def test_many(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s", "insert into testfast (id, val) values %s",
((i, i * 10) for i in range(1000))) ((i, i * 10) for i in range(1000)))
@ -182,7 +182,7 @@ class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
def test_composed(self): def test_composed(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
sql.SQL("insert into {0} (id, val) values %s") sql.SQL("insert into {0} (id, val) values %s")
.format(sql.Identifier('testfast')), .format(sql.Identifier('testfast')),
@ -191,7 +191,7 @@ class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
def test_pages(self): def test_pages(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s", "insert into testfast (id, val) values %s",
((i, i * 10) for i in range(25)), ((i, i * 10) for i in range(25)),
@ -204,7 +204,7 @@ class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
def test_unicode(self): def test_unicode(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
ext.register_type(ext.UNICODE, cur) ext.register_type(ext.UNICODE, cur)
snowman = u"\u2603" snowman = u"\u2603"
@ -230,7 +230,7 @@ class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
self.assertEqual(cur.fetchone(), (3, snowman)) self.assertEqual(cur.fetchone(), (3, snowman))
def test_returning(self): def test_returning(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
result = psycopg2.extras.execute_values(cur, result = psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s returning id", "insert into testfast (id, val) values %s returning id",
((i, i * 10) for i in range(25)), ((i, i * 10) for i in range(25)),
@ -239,7 +239,7 @@ class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
self.assertEqual([r[0] for r in result], list(range(25))) self.assertEqual([r[0] for r in result], list(range(25)))
def test_invalid_sql(self): def test_invalid_sql(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
"insert", []) "insert", [])
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
@ -250,7 +250,7 @@ class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
"insert %f %s", []) "insert %f %s", [])
def test_percent_escape(self): def test_percent_escape(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, data) values %s -- a%%b", "insert into testfast (id, data) values %s -- a%%b",
[(1, 'hi')]) [(1, 'hi')])

View File

@ -71,7 +71,7 @@ class GreenTestCase(ConnectingTestCase):
# a very large query requires a flush loop to be sent to the backend # a very large query requires a flush loop to be sent to the backend
conn = self.conn conn = self.conn
stub = self.set_stub_wait_callback(conn) stub = self.set_stub_wait_callback(conn)
curs = conn.cursor() with conn.cursor() as curs:
for mb in 1, 5, 10, 20, 50: for mb in 1, 5, 10, 20, 50:
size = mb * 1024 * 1024 size = mb * 1024 * 1024
del stub.polls[:] del stub.polls[:]
@ -105,7 +105,7 @@ class GreenTestCase(ConnectingTestCase):
# if there is an error in a green query, don't freak out and close # if there is an error in a green query, don't freak out and close
# the connection # the connection
conn = self.conn conn = self.conn
curs = conn.cursor() with conn.cursor() as curs:
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
curs.execute, "select the unselectable") curs.execute, "select the unselectable")
@ -117,7 +117,7 @@ class GreenTestCase(ConnectingTestCase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_copy_no_hang(self): def test_copy_no_hang(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
cur.execute, "copy (select 1) to stdout") cur.execute, "copy (select 1) to stdout")
@ -137,7 +137,7 @@ class GreenTestCase(ConnectingTestCase):
raise conn.OperationalError("bad state from poll: %s" % state) raise conn.OperationalError("bad state from poll: %s" % state)
stub = self.set_stub_wait_callback(self.conn, wait) stub = self.set_stub_wait_callback(self.conn, wait)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(""" cur.execute("""
select 1; select 1;
do $$ do $$
@ -203,7 +203,7 @@ class CallbackErrorTestCase(ConnectingTestCase):
for i in range(100): for i in range(100):
self.to_error = None self.to_error = None
cnn = self.connect() cnn = self.connect()
cur = cnn.cursor() with cnn.cursor() as cur:
self.to_error = i self.to_error = i
try: try:
cur.execute("select 1") cur.execute("select 1")
@ -220,7 +220,7 @@ class CallbackErrorTestCase(ConnectingTestCase):
for i in range(100): for i in range(100):
self.to_error = None self.to_error = None
cnn = self.connect() cnn = self.connect()
cur = cnn.cursor('foo') with cnn.cursor('foo') as cur:
self.to_error = i self.to_error = i
try: try:
cur.execute("select 1") cur.execute("select 1")
@ -230,6 +230,9 @@ class CallbackErrorTestCase(ConnectingTestCase):
else: else:
# The query completed # The query completed
return return
finally:
# Don't raise an exception in the cursor context manager.
self.to_error = None
self.fail("you should have had a success or an error by now") self.fail("you should have had a success or an error by now")

View File

@ -33,7 +33,7 @@ except ImportError:
@unittest.skipIf(ip is None, "'ipaddress' module not available") @unittest.skipIf(ip is None, "'ipaddress' module not available")
class NetworkingTestCase(testutils.ConnectingTestCase): class NetworkingTestCase(testutils.ConnectingTestCase):
def test_inet_cast(self): def test_inet_cast(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.register_ipaddress(cur) psycopg2.extras.register_ipaddress(cur)
cur.execute("select null::inet") cur.execute("select null::inet")
@ -51,7 +51,7 @@ class NetworkingTestCase(testutils.ConnectingTestCase):
@testutils.skip_before_postgres(8, 2) @testutils.skip_before_postgres(8, 2)
def test_inet_array_cast(self): def test_inet_array_cast(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.register_ipaddress(cur) psycopg2.extras.register_ipaddress(cur)
cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::inet[]") cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::inet[]")
l = cur.fetchone()[0] l = cur.fetchone()[0]
@ -62,7 +62,7 @@ class NetworkingTestCase(testutils.ConnectingTestCase):
self.assert_(isinstance(l[2], ip.IPv6Interface), l) self.assert_(isinstance(l[2], ip.IPv6Interface), l)
def test_inet_adapt(self): def test_inet_adapt(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.register_ipaddress(cur) psycopg2.extras.register_ipaddress(cur)
cur.execute("select %s", [ip.ip_interface('127.0.0.1/24')]) cur.execute("select %s", [ip.ip_interface('127.0.0.1/24')])
@ -72,7 +72,7 @@ class NetworkingTestCase(testutils.ConnectingTestCase):
self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128') self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128')
def test_cidr_cast(self): def test_cidr_cast(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.register_ipaddress(cur) psycopg2.extras.register_ipaddress(cur)
cur.execute("select null::cidr") cur.execute("select null::cidr")
@ -90,7 +90,7 @@ class NetworkingTestCase(testutils.ConnectingTestCase):
@testutils.skip_before_postgres(8, 2) @testutils.skip_before_postgres(8, 2)
def test_cidr_array_cast(self): def test_cidr_array_cast(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.register_ipaddress(cur) psycopg2.extras.register_ipaddress(cur)
cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::cidr[]") cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::cidr[]")
l = cur.fetchone()[0] l = cur.fetchone()[0]
@ -101,7 +101,7 @@ class NetworkingTestCase(testutils.ConnectingTestCase):
self.assert_(isinstance(l[2], ip.IPv6Network), l) self.assert_(isinstance(l[2], ip.IPv6Network), l)
def test_cidr_adapt(self): def test_cidr_adapt(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.register_ipaddress(cur) psycopg2.extras.register_ipaddress(cur)
cur.execute("select %s", [ip.ip_network('127.0.0.0/24')]) cur.execute("select %s", [ip.ip_network('127.0.0.0/24')])

View File

@ -154,7 +154,7 @@ class ConnectTestCase(unittest.TestCase):
class ExceptionsTestCase(ConnectingTestCase): class ExceptionsTestCase(ConnectingTestCase):
def test_attributes(self): def test_attributes(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
try: try:
cur.execute("select * from nonexist") cur.execute("select * from nonexist")
except psycopg2.Error as exc: except psycopg2.Error as exc:
@ -165,7 +165,7 @@ class ExceptionsTestCase(ConnectingTestCase):
self.assert_(e.cursor is cur) self.assert_(e.cursor is cur)
def test_diagnostics_attributes(self): def test_diagnostics_attributes(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
try: try:
cur.execute("select * from nonexist") cur.execute("select * from nonexist")
except psycopg2.Error as exc: except psycopg2.Error as exc:
@ -195,7 +195,7 @@ class ExceptionsTestCase(ConnectingTestCase):
def test_diagnostics_life(self): def test_diagnostics_life(self):
def tmp(): def tmp():
cur = self.conn.cursor() with self.conn.cursor() as cur:
try: try:
cur.execute("select * from nonexist") cur.execute("select * from nonexist")
except psycopg2.Error as exc: except psycopg2.Error as exc:

View File

@ -120,7 +120,8 @@ conn.close()
self.listen('foo') self.listen('foo')
pid = int(self.notify('foo').communicate()[0]) pid = int(self.notify('foo').communicate()[0])
self.assertEqual(0, len(self.conn.notifies)) self.assertEqual(0, len(self.conn.notifies))
self.conn.cursor().execute('select 1;') with self.conn.cursor() as cur:
cur.execute('select 1;')
self.assertEqual(1, len(self.conn.notifies)) self.assertEqual(1, len(self.conn.notifies))
self.assertEqual(pid, self.conn.notifies[0][0]) self.assertEqual(pid, self.conn.notifies[0][0])
self.assertEqual('foo', self.conn.notifies[0][1]) self.assertEqual('foo', self.conn.notifies[0][1])

View File

@ -56,7 +56,7 @@ class QuotingTestCase(ConnectingTestCase):
""" """
data += "".join(map(chr, range(1, 127))) data += "".join(map(chr, range(1, 127)))
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("SELECT %s;", (data,)) curs.execute("SELECT %s;", (data,))
res = curs.fetchone()[0] res = curs.fetchone()[0]
@ -64,7 +64,7 @@ class QuotingTestCase(ConnectingTestCase):
self.assert_(not self.conn.notices) self.assert_(not self.conn.notices)
def test_string_null_terminator(self): def test_string_null_terminator(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
data = 'abcd\x01\x00cdefg' data = 'abcd\x01\x00cdefg'
try: try:
@ -84,7 +84,7 @@ class QuotingTestCase(ConnectingTestCase):
else: else:
data += bytes(list(range(256))) data += bytes(list(range(256)))
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("SELECT %s::bytea;", (psycopg2.Binary(data),)) curs.execute("SELECT %s::bytea;", (psycopg2.Binary(data),))
if PY2: if PY2:
res = str(curs.fetchone()[0]) res = str(curs.fetchone()[0])
@ -99,7 +99,7 @@ class QuotingTestCase(ConnectingTestCase):
self.assert_(not self.conn.notices) self.assert_(not self.conn.notices)
def test_unicode(self): def test_unicode(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("SHOW server_encoding") curs.execute("SHOW server_encoding")
server_encoding = curs.fetchone()[0] server_encoding = curs.fetchone()[0]
if server_encoding != "UTF8": if server_encoding != "UTF8":
@ -123,7 +123,7 @@ class QuotingTestCase(ConnectingTestCase):
def test_latin1(self): def test_latin1(self):
self.conn.set_client_encoding('LATIN1') self.conn.set_client_encoding('LATIN1')
curs = self.conn.cursor() with self.conn.cursor() as curs:
if PY2: if PY2:
data = ''.join(map(chr, range(32, 127) + range(160, 256))) data = ''.join(map(chr, range(32, 127) + range(160, 256)))
else: else:
@ -138,7 +138,8 @@ class QuotingTestCase(ConnectingTestCase):
# as unicode # as unicode
if PY2: if PY2:
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn) psycopg2.extensions.register_type(
psycopg2.extensions.UNICODE, self.conn)
data = data.decode('latin1') data = data.decode('latin1')
curs.execute("SELECT %s::text;", (data,)) curs.execute("SELECT %s::text;", (data,))
@ -148,7 +149,7 @@ class QuotingTestCase(ConnectingTestCase):
def test_koi8(self): def test_koi8(self):
self.conn.set_client_encoding('KOI8') self.conn.set_client_encoding('KOI8')
curs = self.conn.cursor() with self.conn.cursor() as curs:
if PY2: if PY2:
data = ''.join(map(chr, range(32, 127) + range(128, 256))) data = ''.join(map(chr, range(32, 127) + range(128, 256)))
else: else:
@ -176,7 +177,7 @@ class QuotingTestCase(ConnectingTestCase):
conn = self.connect() conn = self.connect()
conn.set_client_encoding('UNICODE') conn.set_client_encoding('UNICODE')
psycopg2.extensions.register_type(psycopg2.extensions.BYTES, conn) psycopg2.extensions.register_type(psycopg2.extensions.BYTES, conn)
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("select %s::text", (snowman,)) curs.execute("select %s::text", (snowman,))
x = curs.fetchone()[0] x = curs.fetchone()[0]
self.assert_(isinstance(x, bytes)) self.assert_(isinstance(x, bytes))

View File

@ -73,8 +73,7 @@ class ReplicationTestCase(ConnectingTestCase):
conn = self.connect() conn = self.connect()
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
try: try:
cur.execute("DROP TABLE dummy1") cur.execute("DROP TABLE dummy1")
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
@ -90,7 +89,7 @@ class ReplicationTest(ReplicationTestCase):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
cur.execute("IDENTIFY_SYSTEM") cur.execute("IDENTIFY_SYSTEM")
cur.fetchall() cur.fetchall()
@ -104,7 +103,7 @@ class ReplicationTest(ReplicationTestCase):
connection_factory=PhysicalReplicationConnection) connection_factory=PhysicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
cur.execute("IDENTIFY_SYSTEM") cur.execute("IDENTIFY_SYSTEM")
cur.fetchall() cur.fetchall()
@ -113,7 +112,7 @@ class ReplicationTest(ReplicationTestCase):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection) conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
cur.execute("IDENTIFY_SYSTEM") cur.execute("IDENTIFY_SYSTEM")
cur.fetchall() cur.fetchall()
@ -122,8 +121,7 @@ class ReplicationTest(ReplicationTestCase):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.create_replication_slot(cur) self.create_replication_slot(cur)
self.assertRaises( self.assertRaises(
psycopg2.ProgrammingError, self.create_replication_slot, cur) psycopg2.ProgrammingError, self.create_replication_slot, cur)
@ -134,8 +132,7 @@ class ReplicationTest(ReplicationTestCase):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
cur.start_replication, self.slot) cur.start_replication, self.slot)
@ -148,8 +145,7 @@ class ReplicationTest(ReplicationTestCase):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection) conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.create_replication_slot(cur, output_plugin='test_decoding') self.create_replication_slot(cur, output_plugin='test_decoding')
cur.start_replication_expert( cur.start_replication_expert(
sql.SQL("START_REPLICATION SLOT {slot} LOGICAL 0/00000000").format( sql.SQL("START_REPLICATION SLOT {slot} LOGICAL 0/00000000").format(
@ -161,8 +157,7 @@ class ReplicationTest(ReplicationTestCase):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection) conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.create_replication_slot(cur, output_plugin='test_decoding') self.create_replication_slot(cur, output_plugin='test_decoding')
self.make_replication_events() self.make_replication_events()
@ -208,8 +203,7 @@ class ReplicationTest(ReplicationTestCase):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection) conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.create_replication_slot(cur, output_plugin='test_decoding') self.create_replication_slot(cur, output_plugin='test_decoding')
self.make_replication_events() self.make_replication_events()
@ -230,8 +224,7 @@ class AsyncReplicationTest(ReplicationTestCase):
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.create_replication_slot(cur, output_plugin='test_decoding') self.create_replication_slot(cur, output_plugin='test_decoding')
self.wait(cur) self.wait(cur)

View File

@ -117,7 +117,7 @@ class SqlFormatTests(ConnectingTestCase):
sql.SQL("select {0};").format(sql.Literal(Foo())).as_string, self.conn) sql.SQL("select {0};").format(sql.Literal(Foo())).as_string, self.conn)
def test_execute(self): def test_execute(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(""" cur.execute("""
create table test_compose ( create table test_compose (
id serial primary key, id serial primary key,
@ -134,7 +134,7 @@ class SqlFormatTests(ConnectingTestCase):
self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')]) self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')])
def test_executemany(self): def test_executemany(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(""" cur.execute("""
create table test_compose ( create table test_compose (
id serial primary key, id serial primary key,
@ -154,7 +154,7 @@ class SqlFormatTests(ConnectingTestCase):
@skip_copy_if_green @skip_copy_if_green
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_copy(self): def test_copy(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(""" cur.execute("""
create table test_compose ( create table test_compose (
id serial primary key, id serial primary key,

View File

@ -37,7 +37,7 @@ class TransactionTests(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
self.conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE) self.conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(''' curs.execute('''
CREATE TEMPORARY TABLE table1 ( CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY id int PRIMARY KEY
@ -55,7 +55,7 @@ class TransactionTests(ConnectingTestCase):
def test_rollback(self): def test_rollback(self):
# Test that rollback undoes changes # Test that rollback undoes changes
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('INSERT INTO table2 VALUES (2, 1)') curs.execute('INSERT INTO table2 VALUES (2, 1)')
# Rollback takes us from BEGIN state to READY state # Rollback takes us from BEGIN state to READY state
self.assertEqual(self.conn.status, STATUS_BEGIN) self.assertEqual(self.conn.status, STATUS_BEGIN)
@ -66,7 +66,7 @@ class TransactionTests(ConnectingTestCase):
def test_commit(self): def test_commit(self):
# Test that commit stores changes # Test that commit stores changes
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('INSERT INTO table2 VALUES (2, 1)') curs.execute('INSERT INTO table2 VALUES (2, 1)')
# Rollback takes us from BEGIN state to READY state # Rollback takes us from BEGIN state to READY state
self.assertEqual(self.conn.status, STATUS_BEGIN) self.assertEqual(self.conn.status, STATUS_BEGIN)
@ -80,7 +80,7 @@ class TransactionTests(ConnectingTestCase):
def test_failed_commit(self): def test_failed_commit(self):
# Test that we can recover from a failed commit. # Test that we can recover from a failed commit.
# We use a deferred constraint to cause a failure on commit. # We use a deferred constraint to cause a failure on commit.
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('SET CONSTRAINTS table2__table1_id__fk DEFERRED') curs.execute('SET CONSTRAINTS table2__table1_id__fk DEFERRED')
curs.execute('INSERT INTO table2 VALUES (2, 42)') curs.execute('INSERT INTO table2 VALUES (2, 42)')
# The commit should fail, and move the cursor back to READY state # The commit should fail, and move the cursor back to READY state
@ -103,7 +103,7 @@ class DeadlockSerializationTests(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
curs = self.conn.cursor() with self.conn.cursor() as curs:
# Drop table if it already exists # Drop table if it already exists
try: try:
curs.execute("DROP TABLE table1") curs.execute("DROP TABLE table1")
@ -126,7 +126,7 @@ class DeadlockSerializationTests(ConnectingTestCase):
self.conn.commit() self.conn.commit()
def tearDown(self): def tearDown(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("DROP TABLE table1") curs.execute("DROP TABLE table1")
curs.execute("DROP TABLE table2") curs.execute("DROP TABLE table2")
self.conn.commit() self.conn.commit()
@ -142,7 +142,7 @@ class DeadlockSerializationTests(ConnectingTestCase):
def task1(): def task1():
try: try:
conn = self.connect() conn = self.connect()
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE") curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE")
step1.set() step1.set()
step2.wait() step2.wait()
@ -155,7 +155,7 @@ class DeadlockSerializationTests(ConnectingTestCase):
def task2(): def task2():
try: try:
conn = self.connect() conn = self.connect()
curs = conn.cursor() with conn.cursor() as curs:
step1.wait() step1.wait()
curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE") curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE")
step2.set() step2.set()
@ -190,7 +190,7 @@ class DeadlockSerializationTests(ConnectingTestCase):
def task1(): def task1():
try: try:
conn = self.connect() conn = self.connect()
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("SELECT name FROM table1 WHERE id = 1") curs.execute("SELECT name FROM table1 WHERE id = 1")
curs.fetchall() curs.fetchall()
step1.set() step1.set()
@ -205,7 +205,7 @@ class DeadlockSerializationTests(ConnectingTestCase):
def task2(): def task2():
try: try:
conn = self.connect() conn = self.connect()
curs = conn.cursor() with conn.cursor() as curs:
step1.wait() step1.wait()
curs.execute("UPDATE table1 SET name='task2' WHERE id = 1") curs.execute("UPDATE table1 SET name='task2' WHERE id = 1")
conn.commit() conn.commit()
@ -240,7 +240,7 @@ class QueryCancellationTests(ConnectingTestCase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_statement_timeout(self): def test_statement_timeout(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
# Set a low statement timeout, then sleep for a longer period. # Set a low statement timeout, then sleep for a longer period.
curs.execute('SET statement_timeout TO 10') curs.execute('SET statement_timeout TO 10')
self.assertRaises(psycopg2.extensions.QueryCanceledError, self.assertRaises(psycopg2.extensions.QueryCanceledError,

View File

@ -41,7 +41,7 @@ class TypesBasicTests(ConnectingTestCase):
"""Test that all type conversions are working.""" """Test that all type conversions are working."""
def execute(self, *args): def execute(self, *args):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(*args) curs.execute(*args)
return curs.fetchone()[0] return curs.fetchone()[0]
@ -156,7 +156,7 @@ class TypesBasicTests(ConnectingTestCase):
def testEmptyArrayRegression(self): def testEmptyArrayRegression(self):
# ticket #42 # ticket #42
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute( curs.execute(
"create table array_test " "create table array_test "
"(id integer, col timestamp without time zone[])") "(id integer, col timestamp without time zone[])")
@ -164,7 +164,8 @@ class TypesBasicTests(ConnectingTestCase):
curs.execute("insert into array_test values (%s, %s)", curs.execute("insert into array_test values (%s, %s)",
(1, [datetime.date(2011, 2, 14)])) (1, [datetime.date(2011, 2, 14)]))
curs.execute("select col from array_test where id = 1") curs.execute("select col from array_test where id = 1")
self.assertEqual(curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)]) self.assertEqual(
curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)])
curs.execute("insert into array_test values (%s, %s)", (2, [])) curs.execute("insert into array_test values (%s, %s)", (2, []))
curs.execute("select col from array_test where id = 2") curs.execute("select col from array_test where id = 2")
@ -173,7 +174,7 @@ class TypesBasicTests(ConnectingTestCase):
@testutils.skip_before_postgres(8, 4) @testutils.skip_before_postgres(8, 4)
def testNestedEmptyArray(self): def testNestedEmptyArray(self):
# issue #788 # issue #788
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select 10 = any(%s::int[])", ([[]], )) curs.execute("select 10 = any(%s::int[])", ([[]], ))
self.assertFalse(curs.fetchone()[0]) self.assertFalse(curs.fetchone()[0])
@ -204,14 +205,14 @@ class TypesBasicTests(ConnectingTestCase):
self.failUnlessEqual(ss, r) self.failUnlessEqual(ss, r)
def testArrayMalformed(self): def testArrayMalformed(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
ss = ['', '{', '{}}', '{' * 20 + '}' * 20] ss = ['', '{', '{}}', '{' * 20 + '}' * 20]
for s in ss: for s in ss:
self.assertRaises(psycopg2.DataError, self.assertRaises(psycopg2.DataError,
psycopg2.extensions.STRINGARRAY, s.encode('utf8'), curs) psycopg2.extensions.STRINGARRAY, s.encode('utf8'), curs)
def testTextArray(self): def testTextArray(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select '{a,b,c}'::text[]") curs.execute("select '{a,b,c}'::text[]")
x = curs.fetchone()[0] x = curs.fetchone()[0]
self.assert_(isinstance(x[0], str)) self.assert_(isinstance(x[0], str))
@ -220,7 +221,7 @@ class TypesBasicTests(ConnectingTestCase):
def testUnicodeArray(self): def testUnicodeArray(self):
psycopg2.extensions.register_type( psycopg2.extensions.register_type(
psycopg2.extensions.UNICODEARRAY, self.conn) psycopg2.extensions.UNICODEARRAY, self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select '{a,b,c}'::text[]") curs.execute("select '{a,b,c}'::text[]")
x = curs.fetchone()[0] x = curs.fetchone()[0]
self.assert_(isinstance(x[0], text_type)) self.assert_(isinstance(x[0], text_type))
@ -229,7 +230,7 @@ class TypesBasicTests(ConnectingTestCase):
def testBytesArray(self): def testBytesArray(self):
psycopg2.extensions.register_type( psycopg2.extensions.register_type(
psycopg2.extensions.BYTESARRAY, self.conn) psycopg2.extensions.BYTESARRAY, self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select '{a,b,c}'::text[]") curs.execute("select '{a,b,c}'::text[]")
x = curs.fetchone()[0] x = curs.fetchone()[0]
self.assert_(isinstance(x[0], bytes)) self.assert_(isinstance(x[0], bytes))
@ -237,7 +238,7 @@ class TypesBasicTests(ConnectingTestCase):
@testutils.skip_before_postgres(8, 2) @testutils.skip_before_postgres(8, 2)
def testArrayOfNulls(self): def testArrayOfNulls(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(""" curs.execute("""
create table na ( create table na (
texta text[], texta text[],
@ -273,7 +274,7 @@ class TypesBasicTests(ConnectingTestCase):
@testutils.skip_before_postgres(8, 2) @testutils.skip_before_postgres(8, 2)
def testNestedArrays(self): def testNestedArrays(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
for a in [ for a in [
[[1]], [[1]],
[[None]], [[None]],

View File

@ -45,7 +45,7 @@ class TypesExtrasTests(ConnectingTestCase):
"""Test that all type conversions are working.""" """Test that all type conversions are working."""
def execute(self, *args): def execute(self, *args):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(*args) curs.execute(*args)
return curs.fetchone()[0] return curs.fetchone()[0]
@ -231,7 +231,7 @@ class HstoreTestCase(ConnectingTestCase):
@skip_if_no_hstore @skip_if_no_hstore
def test_register_conn(self): def test_register_conn(self):
register_hstore(self.conn) register_hstore(self.conn)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
t = cur.fetchone() t = cur.fetchone()
self.assert_(t[0] is None) self.assert_(t[0] is None)
@ -240,7 +240,7 @@ class HstoreTestCase(ConnectingTestCase):
@skip_if_no_hstore @skip_if_no_hstore
def test_register_curs(self): def test_register_curs(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
register_hstore(cur) register_hstore(cur)
cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
t = cur.fetchone() t = cur.fetchone()
@ -252,7 +252,7 @@ class HstoreTestCase(ConnectingTestCase):
@skip_from_python(3) @skip_from_python(3)
def test_register_unicode(self): def test_register_unicode(self):
register_hstore(self.conn, unicode=True) register_hstore(self.conn, unicode=True)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
t = cur.fetchone() t = cur.fetchone()
self.assert_(t[0] is None) self.assert_(t[0] is None)
@ -268,7 +268,7 @@ class HstoreTestCase(ConnectingTestCase):
register_hstore(self.conn, globally=True) register_hstore(self.conn, globally=True)
conn2 = self.connect() conn2 = self.connect()
try: try:
cur2 = self.conn.cursor() with self.conn.cursor() as cur2:
cur2.execute("select 'a => b'::hstore") cur2.execute("select 'a => b'::hstore")
r = cur2.fetchone() r = cur2.fetchone()
self.assert_(isinstance(r[0], dict)) self.assert_(isinstance(r[0], dict))
@ -278,8 +278,7 @@ class HstoreTestCase(ConnectingTestCase):
@skip_if_no_hstore @skip_if_no_hstore
def test_roundtrip(self): def test_roundtrip(self):
register_hstore(self.conn) register_hstore(self.conn)
cur = self.conn.cursor() with self.conn.cursor() as cur:
def ok(d): def ok(d):
cur.execute("select %s", (d,)) cur.execute("select %s", (d,))
d1 = cur.fetchone()[0] d1 = cur.fetchone()[0]
@ -299,7 +298,8 @@ class HstoreTestCase(ConnectingTestCase):
if PY2: if PY2:
ab = map(chr, range(32, 127) + range(160, 255)) ab = map(chr, range(32, 127) + range(160, 255))
else: else:
ab = bytes(list(range(32, 127)) + list(range(160, 255))).decode('latin1') ab = bytes(
list(range(32, 127)) + list(range(160, 255))).decode('latin1')
ok({''.join(ab): ''.join(ab)}) ok({''.join(ab): ''.join(ab)})
ok(dict(zip(ab, ab))) ok(dict(zip(ab, ab)))
@ -308,8 +308,7 @@ class HstoreTestCase(ConnectingTestCase):
@skip_from_python(3) @skip_from_python(3)
def test_roundtrip_unicode(self): def test_roundtrip_unicode(self):
register_hstore(self.conn, unicode=True) register_hstore(self.conn, unicode=True)
cur = self.conn.cursor() with self.conn.cursor() as cur:
def ok(d): def ok(d):
cur.execute("select %s", (d,)) cur.execute("select %s", (d,))
d1 = cur.fetchone()[0] d1 = cur.fetchone()[0]
@ -330,7 +329,7 @@ class HstoreTestCase(ConnectingTestCase):
@skip_if_no_hstore @skip_if_no_hstore
@restore_types @restore_types
def test_oid(self): def test_oid(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 'hstore'::regtype::oid") cur.execute("select 'hstore'::regtype::oid")
oid = cur.fetchone()[0] oid = cur.fetchone()[0]
@ -363,7 +362,7 @@ class HstoreTestCase(ConnectingTestCase):
ds.append({''.join(ab): ''.join(ab)}) ds.append({''.join(ab): ''.join(ab)})
ds.append(dict(zip(ab, ab))) ds.append(dict(zip(ab, ab)))
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select %s", (ds,)) cur.execute("select %s", (ds,))
ds1 = cur.fetchone()[0] ds1 = cur.fetchone()[0]
self.assertEqual(ds, ds1) self.assertEqual(ds, ds1)
@ -372,7 +371,7 @@ class HstoreTestCase(ConnectingTestCase):
@skip_before_postgres(8, 3) @skip_before_postgres(8, 3)
def test_array_cast(self): def test_array_cast(self):
register_hstore(self.conn) register_hstore(self.conn)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select array['a=>1'::hstore, 'b=>2'::hstore];") cur.execute("select array['a=>1'::hstore, 'b=>2'::hstore];")
a = cur.fetchone()[0] a = cur.fetchone()[0]
self.assertEqual(a, [{'a': '1'}, {'b': '2'}]) self.assertEqual(a, [{'a': '1'}, {'b': '2'}])
@ -380,7 +379,7 @@ class HstoreTestCase(ConnectingTestCase):
@skip_if_no_hstore @skip_if_no_hstore
@restore_types @restore_types
def test_array_cast_oid(self): def test_array_cast_oid(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 'hstore'::regtype::oid, 'hstore[]'::regtype::oid") cur.execute("select 'hstore'::regtype::oid, 'hstore[]'::regtype::oid")
oid, aoid = cur.fetchone() oid, aoid = cur.fetchone()
@ -399,7 +398,7 @@ class HstoreTestCase(ConnectingTestCase):
conn = self.connect(connection_factory=RealDictConnection) conn = self.connect(connection_factory=RealDictConnection)
try: try:
register_hstore(conn) register_hstore(conn)
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("select ''::hstore as x") curs.execute("select ''::hstore as x")
self.assertEqual(curs.fetchone()['x'], {}) self.assertEqual(curs.fetchone()['x'], {})
finally: finally:
@ -407,7 +406,7 @@ class HstoreTestCase(ConnectingTestCase):
conn = self.connect(connection_factory=RealDictConnection) conn = self.connect(connection_factory=RealDictConnection)
try: try:
curs = conn.cursor() with conn.cursor() as curs:
register_hstore(curs) register_hstore(curs)
curs.execute("select ''::hstore as x") curs.execute("select ''::hstore as x")
self.assertEqual(curs.fetchone()['x'], {}) self.assertEqual(curs.fetchone()['x'], {})
@ -431,7 +430,7 @@ def skip_if_no_composite(f):
class AdaptTypeTestCase(ConnectingTestCase): class AdaptTypeTestCase(ConnectingTestCase):
@skip_if_no_composite @skip_if_no_composite
def test_none_in_record(self): def test_none_in_record(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
s = curs.mogrify("SELECT %s;", [(42, None)]) s = curs.mogrify("SELECT %s;", [(42, None)])
self.assertEqual(b"SELECT (42, NULL);", s) self.assertEqual(b"SELECT (42, NULL);", s)
curs.execute("SELECT %s;", [(42, None)]) curs.execute("SELECT %s;", [(42, None)])
@ -448,8 +447,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
def getquoted(self): def getquoted(self):
return "NOPE!" return "NOPE!"
curs = self.conn.cursor() with self.conn.cursor() as curs:
orig_adapter = ext.adapters[type(None), ext.ISQLQuote] orig_adapter = ext.adapters[type(None), ext.ISQLQuote]
try: try:
ext.register_adapter(type(None), WonkyAdapter) ext.register_adapter(type(None), WonkyAdapter)
@ -502,7 +500,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
self.assertEqual(t.attnames, ['anint', 'astring', 'adate']) self.assertEqual(t.attnames, ['anint', 'astring', 'adate'])
self.assertEqual(t.atttypes, [23, 25, 1082]) self.assertEqual(t.atttypes, [23, 25, 1082])
curs = self.conn.cursor() with self.conn.cursor() as curs:
r = (10, 'hello', date(2011, 1, 2)) r = (10, 'hello', date(2011, 1, 2))
curs.execute("select %s::type_isd;", (r,)) curs.execute("select %s::type_isd;", (r,))
v = curs.fetchone()[0] v = curs.fetchone()[0]
@ -519,7 +517,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
def test_empty_string(self): def test_empty_string(self):
# issue #141 # issue #141
self._create_type("type_ss", [('s1', 'text'), ('s2', 'text')]) self._create_type("type_ss", [('s1', 'text'), ('s2', 'text')])
curs = self.conn.cursor() with self.conn.cursor() as curs:
psycopg2.extras.register_composite("type_ss", curs) psycopg2.extras.register_composite("type_ss", curs)
def ok(t): def ok(t):
@ -548,7 +546,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
psycopg2.extras.register_composite("type_r_dt", self.conn) psycopg2.extras.register_composite("type_r_dt", self.conn)
psycopg2.extras.register_composite("type_r_ft", self.conn) psycopg2.extras.register_composite("type_r_ft", self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
r = (0.25, (date(2011, 1, 2), (42, "hello"))) r = (0.25, (date(2011, 1, 2), (42, "hello")))
curs.execute("select %s::type_r_ft;", (r,)) curs.execute("select %s::type_r_ft;", (r,))
v = curs.fetchone()[0] v = curs.fetchone()[0]
@ -560,8 +558,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
def test_register_on_cursor(self): def test_register_on_cursor(self):
self._create_type("type_ii", [("a", "integer"), ("b", "integer")]) self._create_type("type_ii", [("a", "integer"), ("b", "integer")])
curs1 = self.conn.cursor() with self.conn.cursor() as curs1, self.conn.cursor() as curs2:
curs2 = self.conn.cursor()
psycopg2.extras.register_composite("type_ii", curs1) psycopg2.extras.register_composite("type_ii", curs1)
curs1.execute("select (1,2)::type_ii") curs1.execute("select (1,2)::type_ii")
self.assertEqual(curs1.fetchone()[0], (1, 2)) self.assertEqual(curs1.fetchone()[0], (1, 2))
@ -576,8 +573,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
conn2 = self.connect() conn2 = self.connect()
try: try:
psycopg2.extras.register_composite("type_ii", conn1) psycopg2.extras.register_composite("type_ii", conn1)
curs1 = conn1.cursor() with conn1.cursor() as curs1, conn2.cursor() as curs2:
curs2 = conn2.cursor()
curs1.execute("select (1,2)::type_ii") curs1.execute("select (1,2)::type_ii")
self.assertEqual(curs1.fetchone()[0], (1, 2)) self.assertEqual(curs1.fetchone()[0], (1, 2))
curs2.execute("select (1,2)::type_ii") curs2.execute("select (1,2)::type_ii")
@ -595,8 +591,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
conn2 = self.connect() conn2 = self.connect()
try: try:
psycopg2.extras.register_composite("type_ii", conn1, globally=True) psycopg2.extras.register_composite("type_ii", conn1, globally=True)
curs1 = conn1.cursor() with conn1.cursor() as curs1, conn2.cursor() as curs2:
curs2 = conn2.cursor()
curs1.execute("select (1,2)::type_ii") curs1.execute("select (1,2)::type_ii")
self.assertEqual(curs1.fetchone()[0], (1, 2)) self.assertEqual(curs1.fetchone()[0], (1, 2))
curs2.execute("select (1,2)::type_ii") curs2.execute("select (1,2)::type_ii")
@ -608,7 +603,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
@skip_if_no_composite @skip_if_no_composite
def test_composite_namespace(self): def test_composite_namespace(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(""" curs.execute("""
select nspname from pg_namespace select nspname from pg_namespace
where nspname = 'typens'; where nspname = 'typens';
@ -633,7 +628,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
t = psycopg2.extras.register_composite("type_isd", self.conn) t = psycopg2.extras.register_composite("type_isd", self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
r1 = (10, 'hello', date(2011, 1, 2)) r1 = (10, 'hello', date(2011, 1, 2))
r2 = (20, 'world', date(2011, 1, 3)) r2 = (20, 'world', date(2011, 1, 3))
curs.execute("select %s::type_isd[];", ([r1, r2],)) curs.execute("select %s::type_isd[];", ([r1, r2],))
@ -652,7 +647,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
def test_wrong_schema(self): def test_wrong_schema(self):
oid = self._create_type("type_ii", [("a", "integer"), ("b", "integer")]) oid = self._create_type("type_ii", [("a", "integer"), ("b", "integer")])
c = CompositeCaster('type_ii', oid, [('a', 23), ('b', 23), ('c', 23)]) c = CompositeCaster('type_ii', oid, [('a', 23), ('b', 23), ('c', 23)])
curs = self.conn.cursor() with self.conn.cursor() as curs:
psycopg2.extensions.register_type(c.typecaster, curs) psycopg2.extensions.register_type(c.typecaster, curs)
curs.execute("select (1,2)::type_ii") curs.execute("select (1,2)::type_ii")
self.assertRaises(psycopg2.DataError, curs.fetchone) self.assertRaises(psycopg2.DataError, curs.fetchone)
@ -661,7 +656,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
@skip_if_no_composite @skip_if_no_composite
@skip_before_postgres(8, 4) @skip_before_postgres(8, 4)
def test_from_tables(self): def test_from_tables(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""create table ctest1 ( curs.execute("""create table ctest1 (
id integer primary key, id integer primary key,
temp int, temp int,
@ -710,7 +705,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
conn = self.connect(connection_factory=RealDictConnection) conn = self.connect(connection_factory=RealDictConnection)
try: try:
register_composite('type_ii', conn) register_composite('type_ii', conn)
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("select '(1,2)'::type_ii as x") curs.execute("select '(1,2)'::type_ii as x")
self.assertEqual(curs.fetchone()['x'], (1, 2)) self.assertEqual(curs.fetchone()['x'], (1, 2))
finally: finally:
@ -718,7 +713,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
conn = self.connect(connection_factory=RealDictConnection) conn = self.connect(connection_factory=RealDictConnection)
try: try:
curs = conn.cursor() with conn.cursor() as curs:
register_composite('type_ii', conn) register_composite('type_ii', conn)
curs.execute("select '(1,2)'::type_ii as x") curs.execute("select '(1,2)'::type_ii as x")
self.assertEqual(curs.fetchone()['x'], (1, 2)) self.assertEqual(curs.fetchone()['x'], (1, 2))
@ -739,7 +734,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
self.assertEqual(t.name, 'type_isd') self.assertEqual(t.name, 'type_isd')
self.assertEqual(t.oid, oid) self.assertEqual(t.oid, oid)
curs = self.conn.cursor() with self.conn.cursor() as curs:
r = (10, 'hello', date(2011, 1, 2)) r = (10, 'hello', date(2011, 1, 2))
curs.execute("select %s::type_isd;", (r,)) curs.execute("select %s::type_isd;", (r,))
v = curs.fetchone()[0] v = curs.fetchone()[0]
@ -749,7 +744,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
self.assertEqual(v['adate'], date(2011, 1, 2)) self.assertEqual(v['adate'], date(2011, 1, 2))
def _create_type(self, name, fields): def _create_type(self, name, fields):
curs = self.conn.cursor() with self.conn.cursor() as curs:
try: try:
curs.execute("drop type %s cascade;" % name) curs.execute("drop type %s cascade;" % name)
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
@ -776,7 +771,7 @@ def skip_if_no_json_type(f):
"""Skip a test if PostgreSQL json type is not available""" """Skip a test if PostgreSQL json type is not available"""
@wraps(f) @wraps(f)
def skip_if_no_json_type_(self): def skip_if_no_json_type_(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select oid from pg_type where typname = 'json'") curs.execute("select oid from pg_type where typname = 'json'")
if not curs.fetchone(): if not curs.fetchone():
return self.skipTest("json not available in test database") return self.skipTest("json not available in test database")
@ -791,7 +786,7 @@ class JsonTestCase(ConnectingTestCase):
objs = [None, "te'xt", 123, 123.45, objs = [None, "te'xt", 123, 123.45,
u'\xe0\u20ac', ['a', 100], {'a': 100}] u'\xe0\u20ac', ['a', 100], {'a': 100}]
curs = self.conn.cursor() with self.conn.cursor() as curs:
for obj in enumerate(objs): for obj in enumerate(objs):
self.assertQuotedEqual(curs.mogrify("%s", (Json(obj),)), self.assertQuotedEqual(curs.mogrify("%s", (Json(obj),)),
psycopg2.extensions.QuotedString(json.dumps(obj)).getquoted()) psycopg2.extensions.QuotedString(json.dumps(obj)).getquoted())
@ -803,7 +798,7 @@ class JsonTestCase(ConnectingTestCase):
return float(obj) return float(obj)
return json.JSONEncoder.default(self, obj) return json.JSONEncoder.default(self, obj)
curs = self.conn.cursor() with self.conn.cursor() as curs:
obj = Decimal('123.45') obj = Decimal('123.45')
def dumps(obj): def dumps(obj):
@ -822,7 +817,7 @@ class JsonTestCase(ConnectingTestCase):
def dumps(self, obj): def dumps(self, obj):
return json.dumps(obj, cls=DecimalEncoder) return json.dumps(obj, cls=DecimalEncoder)
curs = self.conn.cursor() with self.conn.cursor() as curs:
obj = Decimal('123.45') obj = Decimal('123.45')
self.assertQuotedEqual(curs.mogrify("%s", (MyJson(obj),)), b"'123.45'") self.assertQuotedEqual(curs.mogrify("%s", (MyJson(obj),)), b"'123.45'")
@ -830,13 +825,13 @@ class JsonTestCase(ConnectingTestCase):
def test_register_on_dict(self): def test_register_on_dict(self):
psycopg2.extensions.register_adapter(dict, Json) psycopg2.extensions.register_adapter(dict, Json)
curs = self.conn.cursor() with self.conn.cursor() as curs:
obj = {'a': 123} obj = {'a': 123}
self.assertQuotedEqual( self.assertQuotedEqual(
curs.mogrify("%s", (obj,)), b"""'{"a": 123}'""") curs.mogrify("%s", (obj,)), b"""'{"a": 123}'""")
def test_type_not_available(self): def test_type_not_available(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select oid from pg_type where typname = 'json'") curs.execute("select oid from pg_type where typname = 'json'")
if curs.fetchone(): if curs.fetchone():
return self.skipTest("json available in test database") return self.skipTest("json available in test database")
@ -846,8 +841,7 @@ class JsonTestCase(ConnectingTestCase):
@skip_before_postgres(9, 2) @skip_before_postgres(9, 2)
def test_default_cast(self): def test_default_cast(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select '{"a": 100.0, "b": null}'::json""") curs.execute("""select '{"a": 100.0, "b": null}'::json""")
self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None})
@ -857,13 +851,13 @@ class JsonTestCase(ConnectingTestCase):
@skip_if_no_json_type @skip_if_no_json_type
def test_register_on_connection(self): def test_register_on_connection(self):
psycopg2.extras.register_json(self.conn) psycopg2.extras.register_json(self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select '{"a": 100.0, "b": null}'::json""") curs.execute("""select '{"a": 100.0, "b": null}'::json""")
self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None})
@skip_if_no_json_type @skip_if_no_json_type
def test_register_on_cursor(self): def test_register_on_cursor(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
psycopg2.extras.register_json(curs) psycopg2.extras.register_json(curs)
curs.execute("""select '{"a": 100.0, "b": null}'::json""") curs.execute("""select '{"a": 100.0, "b": null}'::json""")
self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None})
@ -872,7 +866,7 @@ class JsonTestCase(ConnectingTestCase):
@restore_types @restore_types
def test_register_globally(self): def test_register_globally(self):
new, newa = psycopg2.extras.register_json(self.conn, globally=True) new, newa = psycopg2.extras.register_json(self.conn, globally=True)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select '{"a": 100.0, "b": null}'::json""") curs.execute("""select '{"a": 100.0, "b": null}'::json""")
self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None})
@ -883,7 +877,7 @@ class JsonTestCase(ConnectingTestCase):
def loads(s): def loads(s):
return json.loads(s, parse_float=Decimal) return json.loads(s, parse_float=Decimal)
psycopg2.extras.register_json(self.conn, loads=loads) psycopg2.extras.register_json(self.conn, loads=loads)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select '{"a": 100.0, "b": null}'::json""") curs.execute("""select '{"a": 100.0, "b": null}'::json""")
data = curs.fetchone()[0] data = curs.fetchone()[0]
self.assert_(isinstance(data['a'], Decimal)) self.assert_(isinstance(data['a'], Decimal))
@ -899,7 +893,7 @@ class JsonTestCase(ConnectingTestCase):
new, newa = psycopg2.extras.register_json( new, newa = psycopg2.extras.register_json(
loads=loads, oid=oid, array_oid=array_oid) loads=loads, oid=oid, array_oid=array_oid)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select '{"a": 100.0, "b": null}'::json""") curs.execute("""select '{"a": 100.0, "b": null}'::json""")
data = curs.fetchone()[0] data = curs.fetchone()[0]
self.assert_(isinstance(data['a'], Decimal)) self.assert_(isinstance(data['a'], Decimal))
@ -907,8 +901,7 @@ class JsonTestCase(ConnectingTestCase):
@skip_before_postgres(9, 2) @skip_before_postgres(9, 2)
def test_register_default(self): def test_register_default(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
def loads(s): def loads(s):
return psycopg2.extras.json.loads(s, parse_float=Decimal) return psycopg2.extras.json.loads(s, parse_float=Decimal)
psycopg2.extras.register_default_json(curs, loads=loads) psycopg2.extras.register_default_json(curs, loads=loads)
@ -926,14 +919,14 @@ class JsonTestCase(ConnectingTestCase):
@skip_if_no_json_type @skip_if_no_json_type
def test_null(self): def test_null(self):
psycopg2.extras.register_json(self.conn) psycopg2.extras.register_json(self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select NULL::json""") curs.execute("""select NULL::json""")
self.assertEqual(curs.fetchone()[0], None) self.assertEqual(curs.fetchone()[0], None)
curs.execute("""select NULL::json[]""") curs.execute("""select NULL::json[]""")
self.assertEqual(curs.fetchone()[0], None) self.assertEqual(curs.fetchone()[0], None)
def test_no_array_oid(self): def test_no_array_oid(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
t1, t2 = psycopg2.extras.register_json(curs, oid=25) t1, t2 = psycopg2.extras.register_json(curs, oid=25)
self.assertEqual(t1.values[0], 25) self.assertEqual(t1.values[0], 25)
self.assertEqual(t2, None) self.assertEqual(t2, None)
@ -956,13 +949,13 @@ class JsonTestCase(ConnectingTestCase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_scs(self): def test_scs(self):
cnn_on = self.connect(options="-c standard_conforming_strings=on") cnn_on = self.connect(options="-c standard_conforming_strings=on")
cur_on = cnn_on.cursor() with cnn_on.cursor() as cur_on:
self.assertEqual( self.assertEqual(
cur_on.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), cur_on.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]),
b'\'{"a": "\\""}\'') b'\'{"a": "\\""}\'')
cnn_off = self.connect(options="-c standard_conforming_strings=off") cnn_off = self.connect(options="-c standard_conforming_strings=off")
cur_off = cnn_off.cursor() with cnn_off.cursor() as cur_off:
self.assertEqual( self.assertEqual(
cur_off.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]), cur_off.mogrify("%s", [psycopg2.extras.Json({'a': '"'})]),
b'E\'{"a": "\\\\""}\'') b'E\'{"a": "\\\\""}\'')
@ -985,8 +978,7 @@ class JsonbTestCase(ConnectingTestCase):
return rv return rv
def test_default_cast(self): def test_default_cast(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None})
@ -995,12 +987,12 @@ class JsonbTestCase(ConnectingTestCase):
def test_register_on_connection(self): def test_register_on_connection(self):
psycopg2.extras.register_json(self.conn, loads=self.myloads, name='jsonb') psycopg2.extras.register_json(self.conn, loads=self.myloads, name='jsonb')
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1})
def test_register_on_cursor(self): def test_register_on_cursor(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
psycopg2.extras.register_json(curs, loads=self.myloads, name='jsonb') psycopg2.extras.register_json(curs, loads=self.myloads, name='jsonb')
curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1})
@ -1009,7 +1001,7 @@ class JsonbTestCase(ConnectingTestCase):
def test_register_globally(self): def test_register_globally(self):
new, newa = psycopg2.extras.register_json(self.conn, new, newa = psycopg2.extras.register_json(self.conn,
loads=self.myloads, globally=True, name='jsonb') loads=self.myloads, globally=True, name='jsonb')
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1})
@ -1020,7 +1012,7 @@ class JsonbTestCase(ConnectingTestCase):
return json.loads(s, parse_float=Decimal) return json.loads(s, parse_float=Decimal)
psycopg2.extras.register_json(self.conn, loads=loads, name='jsonb') psycopg2.extras.register_json(self.conn, loads=loads, name='jsonb')
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
data = curs.fetchone()[0] data = curs.fetchone()[0]
self.assert_(isinstance(data['a'], Decimal)) self.assert_(isinstance(data['a'], Decimal))
@ -1032,8 +1024,7 @@ class JsonbTestCase(ConnectingTestCase):
self.assertEqual(data['a'], 100.0) self.assertEqual(data['a'], 100.0)
def test_register_default(self): def test_register_default(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
def loads(s): def loads(s):
return psycopg2.extras.json.loads(s, parse_float=Decimal) return psycopg2.extras.json.loads(s, parse_float=Decimal)
@ -1050,7 +1041,7 @@ class JsonbTestCase(ConnectingTestCase):
self.assertEqual(data[0]['a'], Decimal('100.0')) self.assertEqual(data[0]['a'], Decimal('100.0'))
def test_null(self): def test_null(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select NULL::jsonb""") curs.execute("""select NULL::jsonb""")
self.assertEqual(curs.fetchone()[0], None) self.assertEqual(curs.fetchone()[0], None)
curs.execute("""select NULL::jsonb[]""") curs.execute("""select NULL::jsonb[]""")
@ -1332,14 +1323,14 @@ class RangeCasterTestCase(ConnectingTestCase):
'daterange', 'tsrange', 'tstzrange') 'daterange', 'tsrange', 'tstzrange')
def test_cast_null(self): def test_cast_null(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
for type in self.builtin_ranges: for type in self.builtin_ranges:
cur.execute("select NULL::%s" % type) cur.execute("select NULL::%s" % type)
r = cur.fetchone()[0] r = cur.fetchone()[0]
self.assertEqual(r, None) self.assertEqual(r, None)
def test_cast_empty(self): def test_cast_empty(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
for type in self.builtin_ranges: for type in self.builtin_ranges:
cur.execute("select 'empty'::%s" % type) cur.execute("select 'empty'::%s" % type)
r = cur.fetchone()[0] r = cur.fetchone()[0]
@ -1347,7 +1338,7 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(r.isempty) self.assert_(r.isempty)
def test_cast_inf(self): def test_cast_inf(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
for type in self.builtin_ranges: for type in self.builtin_ranges:
cur.execute("select '(,)'::%s" % type) cur.execute("select '(,)'::%s" % type)
r = cur.fetchone()[0] r = cur.fetchone()[0]
@ -1357,7 +1348,7 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(r.upper_inf) self.assert_(r.upper_inf)
def test_cast_numbers(self): def test_cast_numbers(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
for type in ('int4range', 'int8range'): for type in ('int4range', 'int8range'):
cur.execute("select '(10,20)'::%s" % type) cur.execute("select '(10,20)'::%s" % type)
r = cur.fetchone()[0] r = cur.fetchone()[0]
@ -1382,7 +1373,7 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(not r.upper_inc) self.assert_(not r.upper_inc)
def test_cast_date(self): def test_cast_date(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select '(2000-01-01,2012-12-31)'::daterange") cur.execute("select '(2000-01-01,2012-12-31)'::daterange")
r = cur.fetchone()[0] r = cur.fetchone()[0]
self.assert_(isinstance(r, DateRange)) self.assert_(isinstance(r, DateRange))
@ -1395,7 +1386,7 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(not r.upper_inc) self.assert_(not r.upper_inc)
def test_cast_timestamp(self): def test_cast_timestamp(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
ts1 = datetime(2000, 1, 1) ts1 = datetime(2000, 1, 1)
ts2 = datetime(2000, 12, 31, 23, 59, 59, 999) ts2 = datetime(2000, 12, 31, 23, 59, 59, 999)
cur.execute("select tsrange(%s, %s, '()')", (ts1, ts2)) cur.execute("select tsrange(%s, %s, '()')", (ts1, ts2))
@ -1410,7 +1401,7 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(not r.upper_inc) self.assert_(not r.upper_inc)
def test_cast_timestamptz(self): def test_cast_timestamptz(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600)) ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600))
ts2 = datetime(2000, 12, 31, 23, 59, 59, 999, ts2 = datetime(2000, 12, 31, 23, 59, 59, 999,
tzinfo=FixedOffsetTimezone(600)) tzinfo=FixedOffsetTimezone(600))
@ -1426,8 +1417,7 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(r.upper_inc) self.assert_(r.upper_inc)
def test_adapt_number_range(self): def test_adapt_number_range(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
r = NumericRange(empty=True) r = NumericRange(empty=True)
cur.execute("select %s::int4range", (r,)) cur.execute("select %s::int4range", (r,))
r1 = cur.fetchone()[0] r1 = cur.fetchone()[0]
@ -1453,8 +1443,7 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(r1.upper_inc) self.assert_(r1.upper_inc)
def test_adapt_numeric_range(self): def test_adapt_numeric_range(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
r = NumericRange(empty=True) r = NumericRange(empty=True)
cur.execute("select %s::int4range", (r,)) cur.execute("select %s::int4range", (r,))
r1 = cur.fetchone()[0] r1 = cur.fetchone()[0]
@ -1480,8 +1469,7 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(r1.upper_inc) self.assert_(r1.upper_inc)
def test_adapt_date_range(self): def test_adapt_date_range(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
d1 = date(2012, 1, 1) d1 = date(2012, 1, 1)
d2 = date(2012, 12, 31) d2 = date(2012, 12, 31)
r = DateRange(d1, d2) r = DateRange(d1, d2)
@ -1513,7 +1501,7 @@ class RangeCasterTestCase(ConnectingTestCase):
@restore_types @restore_types
def test_register_range_adapter(self): def test_register_range_adapter(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("create type textrange as range (subtype=text)") cur.execute("create type textrange as range (subtype=text)")
rc = register_range('textrange', 'TextRange', cur) rc = register_range('textrange', 'TextRange', cur)
@ -1539,7 +1527,7 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(r1.upper_inc) self.assert_(r1.upper_inc)
def test_range_escaping(self): def test_range_escaping(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("create type textrange as range (subtype=text)") cur.execute("create type textrange as range (subtype=text)")
rc = register_range('textrange', 'TextRange', cur) rc = register_range('textrange', 'TextRange', cur)
@ -1592,13 +1580,13 @@ class RangeCasterTestCase(ConnectingTestCase):
del ext.adapters[TextRange, ext.ISQLQuote] del ext.adapters[TextRange, ext.ISQLQuote]
def test_range_not_found(self): def test_range_not_found(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
register_range, 'nosuchrange', 'FailRange', cur) register_range, 'nosuchrange', 'FailRange', cur)
@restore_types @restore_types
def test_schema_range(self): def test_schema_range(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("create schema rs") cur.execute("create schema rs")
cur.execute("create type r1 as range (subtype=text)") cur.execute("create type r1 as range (subtype=text)")
cur.execute("create type r2 as range (subtype=text)") cur.execute("create type r2 as range (subtype=text)")

View File

@ -33,7 +33,7 @@ from .testutils import ConnectingTestCase, skip_before_postgres
class WithTestCase(ConnectingTestCase): class WithTestCase(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
curs = self.conn.cursor() with self.conn.cursor() as curs:
try: try:
curs.execute("delete from test_with") curs.execute("delete from test_with")
self.conn.commit() self.conn.commit()
@ -49,49 +49,49 @@ class WithConnectionTestCase(WithTestCase):
with self.conn as conn: with self.conn as conn:
self.assert_(self.conn is conn) self.assert_(self.conn is conn)
self.assertEqual(conn.status, ext.STATUS_READY) self.assertEqual(conn.status, ext.STATUS_READY)
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("insert into test_with values (1)") curs.execute("insert into test_with values (1)")
self.assertEqual(conn.status, ext.STATUS_BEGIN) self.assertEqual(conn.status, ext.STATUS_BEGIN)
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed) self.assert_(not self.conn.closed)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(1,)]) self.assertEqual(curs.fetchall(), [(1,)])
def test_with_connect_idiom(self): def test_with_connect_idiom(self):
with self.connect() as conn: with self.connect() as conn:
self.assertEqual(conn.status, ext.STATUS_READY) self.assertEqual(conn.status, ext.STATUS_READY)
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("insert into test_with values (2)") curs.execute("insert into test_with values (2)")
self.assertEqual(conn.status, ext.STATUS_BEGIN) self.assertEqual(conn.status, ext.STATUS_BEGIN)
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed) self.assert_(not self.conn.closed)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(2,)]) self.assertEqual(curs.fetchall(), [(2,)])
def test_with_error_db(self): def test_with_error_db(self):
def f(): def f():
with self.conn as conn: with self.conn as conn:
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("insert into test_with values ('a')") curs.execute("insert into test_with values ('a')")
self.assertRaises(psycopg2.DataError, f) self.assertRaises(psycopg2.DataError, f)
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed) self.assert_(not self.conn.closed)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), []) self.assertEqual(curs.fetchall(), [])
def test_with_error_python(self): def test_with_error_python(self):
def f(): def f():
with self.conn as conn: with self.conn as conn:
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("insert into test_with values (3)") curs.execute("insert into test_with values (3)")
1 / 0 1 / 0
@ -99,7 +99,7 @@ class WithConnectionTestCase(WithTestCase):
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed) self.assert_(not self.conn.closed)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), []) self.assertEqual(curs.fetchall(), [])
@ -120,13 +120,13 @@ class WithConnectionTestCase(WithTestCase):
super(MyConn, self).commit() super(MyConn, self).commit()
with self.connect(connection_factory=MyConn) as conn: with self.connect(connection_factory=MyConn) as conn:
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("insert into test_with values (10)") curs.execute("insert into test_with values (10)")
self.assertEqual(conn.status, ext.STATUS_READY) self.assertEqual(conn.status, ext.STATUS_READY)
self.assert_(commits) self.assert_(commits)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(10,)]) self.assertEqual(curs.fetchall(), [(10,)])
@ -140,7 +140,7 @@ class WithConnectionTestCase(WithTestCase):
try: try:
with self.connect(connection_factory=MyConn) as conn: with self.connect(connection_factory=MyConn) as conn:
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("insert into test_with values (11)") curs.execute("insert into test_with values (11)")
1 / 0 1 / 0
except ZeroDivisionError: except ZeroDivisionError:
@ -151,7 +151,7 @@ class WithConnectionTestCase(WithTestCase):
self.assertEqual(conn.status, ext.STATUS_READY) self.assertEqual(conn.status, ext.STATUS_READY)
self.assert_(rollbacks) self.assert_(rollbacks)
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), []) self.assertEqual(curs.fetchall(), [])
@ -168,7 +168,7 @@ class WithCursorTestCase(WithTestCase):
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed) self.assert_(not self.conn.closed)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(4,)]) self.assertEqual(curs.fetchall(), [(4,)])
@ -185,7 +185,7 @@ class WithCursorTestCase(WithTestCase):
self.assert_(not self.conn.closed) self.assert_(not self.conn.closed)
self.assert_(curs.closed) self.assert_(curs.closed)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), []) self.assertEqual(curs.fetchall(), [])

View File

@ -228,7 +228,7 @@ def skip_if_no_uuid(f):
@wraps(f) @wraps(f)
def skip_if_no_uuid_(self): def skip_if_no_uuid_(self):
try: try:
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select typname from pg_type where typname = 'uuid'") cur.execute("select typname from pg_type where typname = 'uuid'")
has = cur.fetchone() has = cur.fetchone()
finally: finally:
@ -249,7 +249,7 @@ def skip_if_tpc_disabled(f):
def skip_if_tpc_disabled_(self): def skip_if_tpc_disabled_(self):
cnn = self.connect() cnn = self.connect()
try: try:
cur = cnn.cursor() with cnn.cursor() as cur:
try: try:
cur.execute("SHOW max_prepared_transactions;") cur.execute("SHOW max_prepared_transactions;")
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError: