diff --git a/lib/_json.py b/lib/_json.py index eac37972..9dcc227c 100644 --- a/lib/_json.py +++ b/lib/_json.py @@ -180,23 +180,22 @@ def _get_json_oids(conn_or_curs, name='json'): from psycopg2.extensions import STATUS_IN_TRANSACTION from psycopg2.extras import _solve_conn_curs - conn, curs = _solve_conn_curs(conn_or_curs) + with _solve_conn_curs(conn_or_curs) as (conn, curs): + # Store the transaction status of the connection to revert it after use + conn_status = conn.status - # Store the transaction status of the connection to revert it after use - conn_status = conn.status + # column typarray not available before PG 8.3 + typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" - # column typarray not available before PG 8.3 - typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" + # get the oid for the hstore + curs.execute( + "SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;" + % typarray, (name,)) + r = curs.fetchone() - # get the oid for the hstore - curs.execute( - "SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;" - % typarray, (name,)) - r = curs.fetchone() - - # revert the status of the connection as before the command - if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit: - conn.rollback() + # revert the status of the connection as before the command + if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit: + conn.rollback() if not r: raise conn.ProgrammingError("%s data type not found" % name) diff --git a/lib/_range.py b/lib/_range.py index b668fb63..607d85b0 100644 --- a/lib/_range.py +++ b/lib/_range.py @@ -351,25 +351,24 @@ class RangeCaster(object): """ from psycopg2.extensions import STATUS_IN_TRANSACTION from psycopg2.extras import _solve_conn_curs - conn, curs = _solve_conn_curs(conn_or_curs) + with _solve_conn_curs(conn_or_curs) as (conn, curs): + if conn.info.server_version < 90200: + raise ProgrammingError("range types not available in version %s" + % conn.info.server_version) - if conn.info.server_version < 90200: - raise ProgrammingError("range types not available in version %s" - % conn.info.server_version) + # Store the transaction status of the connection to revert it after use + conn_status = conn.status - # Store the transaction status of the connection to revert it after use - conn_status = conn.status + # Use the correct schema + if '.' in name: + schema, tname = name.split('.', 1) + else: + tname = name + schema = 'public' - # Use the correct schema - if '.' in name: - schema, tname = name.split('.', 1) - else: - tname = name - schema = 'public' - - # get the type oid and attributes - try: - curs.execute("""\ + # get the type oid and attributes + try: + curs.execute("""\ select rngtypid, rngsubtype, (select typarray from pg_type where oid = rngtypid) from pg_range r @@ -378,17 +377,17 @@ join pg_namespace ns on ns.oid = typnamespace where typname = %s and ns.nspname = %s; """, (tname, schema)) - except ProgrammingError: - if not conn.autocommit: - conn.rollback() - raise - else: - rec = curs.fetchone() + except ProgrammingError: + if not conn.autocommit: + conn.rollback() + raise + else: + rec = curs.fetchone() - # revert the status of the connection as before the command - if (conn_status != STATUS_IN_TRANSACTION - and not conn.autocommit): - conn.rollback() + # revert the status of the connection as before the command + if (conn_status != STATUS_IN_TRANSACTION + and not conn.autocommit): + conn.rollback() if not rec: raise ProgrammingError( diff --git a/lib/extras.py b/lib/extras.py index 236a5b7e..6c378bb2 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -26,6 +26,7 @@ and classes until a better place in the distribution is found. # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # License for more details. +import contextlib import os as _os import time as _time import re as _re @@ -781,6 +782,7 @@ def wait_select(conn): continue +@contextlib.contextmanager def _solve_conn_curs(conn_or_curs): """Return the connection and a DBAPI cursor from a connection or cursor.""" if conn_or_curs is None: @@ -793,7 +795,8 @@ def _solve_conn_curs(conn_or_curs): conn = conn_or_curs curs = conn.cursor(cursor_factory=_cursor) - return conn, curs + with curs: + yield conn, curs class HstoreAdapter(object): @@ -904,31 +907,30 @@ class HstoreAdapter(object): def get_oids(self, conn_or_curs): """Return the lists of OID of the hstore and hstore[] types. """ - conn, curs = _solve_conn_curs(conn_or_curs) + with _solve_conn_curs(conn_or_curs) as (conn, curs): + # Store the transaction status of the connection to revert it after use + conn_status = conn.status - # Store the transaction status of the connection to revert it after use - conn_status = conn.status + # column typarray not available before PG 8.3 + typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" - # column typarray not available before PG 8.3 - typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" + rv0, rv1 = [], [] - rv0, rv1 = [], [] - - # get the oid for the hstore - curs.execute("""\ + # get the oid for the hstore + curs.execute("""\ SELECT t.oid, %s FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid WHERE typname = 'hstore'; """ % typarray) - for oids in curs: - rv0.append(oids[0]) - rv1.append(oids[1]) + for oids in curs: + rv0.append(oids[0]) + rv1.append(oids[1]) - # revert the status of the connection as before the command - if (conn_status != _ext.STATUS_IN_TRANSACTION - and not conn.autocommit): - conn.rollback() + # revert the status of the connection as before the command + if (conn_status != _ext.STATUS_IN_TRANSACTION + and not conn.autocommit): + conn.rollback() return tuple(rv0), tuple(rv1) @@ -1083,23 +1085,22 @@ class CompositeCaster(object): Raise `ProgrammingError` if the type is not found. """ - conn, curs = _solve_conn_curs(conn_or_curs) + with _solve_conn_curs(conn_or_curs) as (conn, curs): + # Store the transaction status of the connection to revert it after use + conn_status = conn.status - # Store the transaction status of the connection to revert it after use - conn_status = conn.status + # Use the correct schema + if '.' in name: + schema, tname = name.split('.', 1) + else: + tname = name + schema = 'public' - # Use the correct schema - if '.' in name: - schema, tname = name.split('.', 1) - else: - tname = name - schema = 'public' + # column typarray not available before PG 8.3 + typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" - # column typarray not available before PG 8.3 - typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" - - # get the type oid and attributes - curs.execute("""\ + # get the type oid and attributes + curs.execute("""\ SELECT t.oid, %s, attname, atttypid FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid @@ -1109,12 +1110,12 @@ WHERE typname = %%s AND nspname = %%s ORDER BY attnum; """ % typarray, (tname, schema)) - recs = curs.fetchall() + recs = curs.fetchall() - # revert the status of the connection as before the command - if (conn_status != _ext.STATUS_IN_TRANSACTION - and not conn.autocommit): - conn.rollback() + # revert the status of the connection as before the command + if (conn_status != _ext.STATUS_IN_TRANSACTION + and not conn.autocommit): + conn.rollback() if not recs: raise psycopg2.ProgrammingError( diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py index 91e4a8ea..e4828077 100755 --- a/tests/test_types_extras.py +++ b/tests/test_types_extras.py @@ -36,7 +36,7 @@ from psycopg2._json import _get_json_oids from psycopg2.extras import ( CompositeCaster, DateRange, DateTimeRange, DateTimeTZRange, HstoreAdapter, Inet, Json, NumericRange, Range, RealDictConnection, - register_composite, register_hstore, register_range, + register_composite, register_hstore, register_range, _solve_conn_curs ) from psycopg2.tz import FixedOffsetTimezone @@ -1624,6 +1624,32 @@ class RangeCasterTestCase(ConnectingTestCase): cur.execute("rollback to savepoint x;") +class TestSolveConnCurs(ConnectingTestCase): + def test_pass_connection(self): + with _solve_conn_curs(self.conn) as (conn, curs): + self.assertIsInstance(conn, psycopg2.extensions.connection) + self.assertIsInstance(curs, psycopg2.extensions.cursor) + self.assertIs(conn, self.conn) + self.assertFalse(conn.closed) + self.assertFalse(curs.closed) + self.assertFalse(conn.closed) + self.assertTrue(curs.closed) + + def test_pass_cursor(self): + with self.conn.cursor() as cursor: + with _solve_conn_curs(cursor) as (conn, curs): + self.assertIsInstance(conn, psycopg2.extensions.connection) + self.assertIsInstance(curs, psycopg2.extensions.cursor) + self.assertIs(conn, self.conn) + self.assertIsNot(curs, cursor) + self.assertFalse(conn.closed) + self.assertFalse(curs.closed) + self.assertFalse(conn.closed) + self.assertTrue(curs.closed) + self.assertFalse(conn.closed) + self.assertTrue(curs.closed) + + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)