Improve resource ownership semantics of _solve_conn_curs()

A new cursor is always creates so always close it once finished.
This commit is contained in:
Jon Dufresne 2020-02-04 05:07:58 -08:00
parent 9bcca1a7b0
commit e9ae67ff07
4 changed files with 102 additions and 77 deletions

View File

@ -180,8 +180,7 @@ def _get_json_oids(conn_or_curs, name='json'):
from psycopg2.extensions import STATUS_IN_TRANSACTION
from psycopg2.extras import _solve_conn_curs
conn, curs = _solve_conn_curs(conn_or_curs)
with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status

View File

@ -351,8 +351,7 @@ class RangeCaster(object):
"""
from psycopg2.extensions import STATUS_IN_TRANSACTION
from psycopg2.extras import _solve_conn_curs
conn, curs = _solve_conn_curs(conn_or_curs)
with _solve_conn_curs(conn_or_curs) as (conn, curs):
if conn.info.server_version < 90200:
raise ProgrammingError("range types not available in version %s"
% conn.info.server_version)

View File

@ -26,6 +26,7 @@ and classes until a better place in the distribution is found.
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import contextlib
import os as _os
import time as _time
import re as _re
@ -781,6 +782,7 @@ def wait_select(conn):
continue
@contextlib.contextmanager
def _solve_conn_curs(conn_or_curs):
"""Return the connection and a DBAPI cursor from a connection or cursor."""
if conn_or_curs is None:
@ -793,7 +795,8 @@ def _solve_conn_curs(conn_or_curs):
conn = conn_or_curs
curs = conn.cursor(cursor_factory=_cursor)
return conn, curs
with curs:
yield conn, curs
class HstoreAdapter(object):
@ -904,8 +907,7 @@ class HstoreAdapter(object):
def get_oids(self, conn_or_curs):
"""Return the lists of OID of the hstore and hstore[] types.
"""
conn, curs = _solve_conn_curs(conn_or_curs)
with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
@ -1083,8 +1085,7 @@ class CompositeCaster(object):
Raise `ProgrammingError` if the type is not found.
"""
conn, curs = _solve_conn_curs(conn_or_curs)
with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status

View File

@ -36,7 +36,7 @@ from psycopg2._json import _get_json_oids
from psycopg2.extras import (
CompositeCaster, DateRange, DateTimeRange, DateTimeTZRange, HstoreAdapter,
Inet, Json, NumericRange, Range, RealDictConnection,
register_composite, register_hstore, register_range,
register_composite, register_hstore, register_range, _solve_conn_curs
)
from psycopg2.tz import FixedOffsetTimezone
@ -1624,6 +1624,32 @@ class RangeCasterTestCase(ConnectingTestCase):
cur.execute("rollback to savepoint x;")
class TestSolveConnCurs(ConnectingTestCase):
def test_pass_connection(self):
with _solve_conn_curs(self.conn) as (conn, curs):
self.assertIsInstance(conn, psycopg2.extensions.connection)
self.assertIsInstance(curs, psycopg2.extensions.cursor)
self.assertIs(conn, self.conn)
self.assertFalse(conn.closed)
self.assertFalse(curs.closed)
self.assertFalse(conn.closed)
self.assertTrue(curs.closed)
def test_pass_cursor(self):
with self.conn.cursor() as cursor:
with _solve_conn_curs(cursor) as (conn, curs):
self.assertIsInstance(conn, psycopg2.extensions.connection)
self.assertIsInstance(curs, psycopg2.extensions.cursor)
self.assertIs(conn, self.conn)
self.assertIsNot(curs, cursor)
self.assertFalse(conn.closed)
self.assertFalse(curs.closed)
self.assertFalse(conn.closed)
self.assertTrue(curs.closed)
self.assertFalse(conn.closed)
self.assertTrue(curs.closed)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)