This commit is contained in:
Jon Dufresne 2021-04-22 18:28:47 -04:00 committed by GitHub
commit 38b0e253d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 102 additions and 77 deletions

View File

@ -180,8 +180,7 @@ def _get_json_oids(conn_or_curs, name='json'):
from psycopg2.extensions import STATUS_IN_TRANSACTION
from psycopg2.extras import _solve_conn_curs
conn, curs = _solve_conn_curs(conn_or_curs)
with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status

View File

@ -351,8 +351,7 @@ class RangeCaster(object):
"""
from psycopg2.extensions import STATUS_IN_TRANSACTION
from psycopg2.extras import _solve_conn_curs
conn, curs = _solve_conn_curs(conn_or_curs)
with _solve_conn_curs(conn_or_curs) as (conn, curs):
if conn.info.server_version < 90200:
raise ProgrammingError("range types not available in version %s"
% conn.info.server_version)

View File

@ -26,6 +26,7 @@ and classes until a better place in the distribution is found.
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import contextlib
import os as _os
import time as _time
import re as _re
@ -790,6 +791,7 @@ def wait_select(conn):
continue
@contextlib.contextmanager
def _solve_conn_curs(conn_or_curs):
"""Return the connection and a DBAPI cursor from a connection or cursor."""
if conn_or_curs is None:
@ -802,7 +804,8 @@ def _solve_conn_curs(conn_or_curs):
conn = conn_or_curs
curs = conn.cursor(cursor_factory=_cursor)
return conn, curs
with curs:
yield conn, curs
class HstoreAdapter(object):
@ -913,8 +916,7 @@ class HstoreAdapter(object):
def get_oids(self, conn_or_curs):
"""Return the lists of OID of the hstore and hstore[] types.
"""
conn, curs = _solve_conn_curs(conn_or_curs)
with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
@ -1092,8 +1094,7 @@ class CompositeCaster(object):
Raise `ProgrammingError` if the type is not found.
"""
conn, curs = _solve_conn_curs(conn_or_curs)
with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status

View File

@ -36,7 +36,7 @@ from psycopg2._json import _get_json_oids
from psycopg2.extras import (
CompositeCaster, DateRange, DateTimeRange, DateTimeTZRange, HstoreAdapter,
Inet, Json, NumericRange, Range, RealDictConnection,
register_composite, register_hstore, register_range,
register_composite, register_hstore, register_range, _solve_conn_curs
)
from psycopg2.tz import FixedOffsetTimezone
@ -1632,6 +1632,32 @@ class RangeCasterTestCase(ConnectingTestCase):
cur.execute("rollback to savepoint x;")
class TestSolveConnCurs(ConnectingTestCase):
def test_pass_connection(self):
with _solve_conn_curs(self.conn) as (conn, curs):
self.assertIsInstance(conn, psycopg2.extensions.connection)
self.assertIsInstance(curs, psycopg2.extensions.cursor)
self.assertIs(conn, self.conn)
self.assertFalse(conn.closed)
self.assertFalse(curs.closed)
self.assertFalse(conn.closed)
self.assertTrue(curs.closed)
def test_pass_cursor(self):
with self.conn.cursor() as cursor:
with _solve_conn_curs(cursor) as (conn, curs):
self.assertIsInstance(conn, psycopg2.extensions.connection)
self.assertIsInstance(curs, psycopg2.extensions.cursor)
self.assertIs(conn, self.conn)
self.assertIsNot(curs, cursor)
self.assertFalse(conn.closed)
self.assertFalse(curs.closed)
self.assertFalse(conn.closed)
self.assertTrue(curs.closed)
self.assertFalse(conn.closed)
self.assertTrue(curs.closed)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)