This commit is contained in:
Jon Dufresne 2021-04-22 18:28:47 -04:00 committed by GitHub
commit 38b0e253d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 102 additions and 77 deletions

View File

@ -180,23 +180,22 @@ def _get_json_oids(conn_or_curs, name='json'):
from psycopg2.extensions import STATUS_IN_TRANSACTION from psycopg2.extensions import STATUS_IN_TRANSACTION
from psycopg2.extras import _solve_conn_curs from psycopg2.extras import _solve_conn_curs
conn, curs = _solve_conn_curs(conn_or_curs) with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# Store the transaction status of the connection to revert it after use # column typarray not available before PG 8.3
conn_status = conn.status typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
# column typarray not available before PG 8.3 # get the oid for the hstore
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" curs.execute(
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;"
% typarray, (name,))
r = curs.fetchone()
# get the oid for the hstore # revert the status of the connection as before the command
curs.execute( if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;" conn.rollback()
% typarray, (name,))
r = curs.fetchone()
# revert the status of the connection as before the command
if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
conn.rollback()
if not r: if not r:
raise conn.ProgrammingError("%s data type not found" % name) raise conn.ProgrammingError("%s data type not found" % name)

View File

@ -351,25 +351,24 @@ class RangeCaster(object):
""" """
from psycopg2.extensions import STATUS_IN_TRANSACTION from psycopg2.extensions import STATUS_IN_TRANSACTION
from psycopg2.extras import _solve_conn_curs from psycopg2.extras import _solve_conn_curs
conn, curs = _solve_conn_curs(conn_or_curs) with _solve_conn_curs(conn_or_curs) as (conn, curs):
if conn.info.server_version < 90200:
raise ProgrammingError("range types not available in version %s"
% conn.info.server_version)
if conn.info.server_version < 90200: # Store the transaction status of the connection to revert it after use
raise ProgrammingError("range types not available in version %s" conn_status = conn.status
% conn.info.server_version)
# Store the transaction status of the connection to revert it after use # Use the correct schema
conn_status = conn.status if '.' in name:
schema, tname = name.split('.', 1)
else:
tname = name
schema = 'public'
# Use the correct schema # get the type oid and attributes
if '.' in name: try:
schema, tname = name.split('.', 1) curs.execute("""\
else:
tname = name
schema = 'public'
# get the type oid and attributes
try:
curs.execute("""\
select rngtypid, rngsubtype, select rngtypid, rngsubtype,
(select typarray from pg_type where oid = rngtypid) (select typarray from pg_type where oid = rngtypid)
from pg_range r from pg_range r
@ -378,17 +377,17 @@ join pg_namespace ns on ns.oid = typnamespace
where typname = %s and ns.nspname = %s; where typname = %s and ns.nspname = %s;
""", (tname, schema)) """, (tname, schema))
except ProgrammingError: except ProgrammingError:
if not conn.autocommit: if not conn.autocommit:
conn.rollback() conn.rollback()
raise raise
else: else:
rec = curs.fetchone() rec = curs.fetchone()
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != STATUS_IN_TRANSACTION if (conn_status != STATUS_IN_TRANSACTION
and not conn.autocommit): and not conn.autocommit):
conn.rollback() conn.rollback()
if not rec: if not rec:
raise ProgrammingError( raise ProgrammingError(

View File

@ -26,6 +26,7 @@ and classes until a better place in the distribution is found.
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details. # License for more details.
import contextlib
import os as _os import os as _os
import time as _time import time as _time
import re as _re import re as _re
@ -790,6 +791,7 @@ def wait_select(conn):
continue continue
@contextlib.contextmanager
def _solve_conn_curs(conn_or_curs): def _solve_conn_curs(conn_or_curs):
"""Return the connection and a DBAPI cursor from a connection or cursor.""" """Return the connection and a DBAPI cursor from a connection or cursor."""
if conn_or_curs is None: if conn_or_curs is None:
@ -802,7 +804,8 @@ def _solve_conn_curs(conn_or_curs):
conn = conn_or_curs conn = conn_or_curs
curs = conn.cursor(cursor_factory=_cursor) curs = conn.cursor(cursor_factory=_cursor)
return conn, curs with curs:
yield conn, curs
class HstoreAdapter(object): class HstoreAdapter(object):
@ -913,31 +916,30 @@ class HstoreAdapter(object):
def get_oids(self, conn_or_curs): def get_oids(self, conn_or_curs):
"""Return the lists of OID of the hstore and hstore[] types. """Return the lists of OID of the hstore and hstore[] types.
""" """
conn, curs = _solve_conn_curs(conn_or_curs) with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# Store the transaction status of the connection to revert it after use # column typarray not available before PG 8.3
conn_status = conn.status typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
# column typarray not available before PG 8.3 rv0, rv1 = [], []
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
rv0, rv1 = [], [] # get the oid for the hstore
curs.execute("""\
# get the oid for the hstore
curs.execute("""\
SELECT t.oid, %s SELECT t.oid, %s
FROM pg_type t JOIN pg_namespace ns FROM pg_type t JOIN pg_namespace ns
ON typnamespace = ns.oid ON typnamespace = ns.oid
WHERE typname = 'hstore'; WHERE typname = 'hstore';
""" % typarray) """ % typarray)
for oids in curs: for oids in curs:
rv0.append(oids[0]) rv0.append(oids[0])
rv1.append(oids[1]) rv1.append(oids[1])
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION if (conn_status != _ext.STATUS_IN_TRANSACTION
and not conn.autocommit): and not conn.autocommit):
conn.rollback() conn.rollback()
return tuple(rv0), tuple(rv1) return tuple(rv0), tuple(rv1)
@ -1092,23 +1094,22 @@ class CompositeCaster(object):
Raise `ProgrammingError` if the type is not found. Raise `ProgrammingError` if the type is not found.
""" """
conn, curs = _solve_conn_curs(conn_or_curs) with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# Store the transaction status of the connection to revert it after use # Use the correct schema
conn_status = conn.status if '.' in name:
schema, tname = name.split('.', 1)
else:
tname = name
schema = 'public'
# Use the correct schema # column typarray not available before PG 8.3
if '.' in name: typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
schema, tname = name.split('.', 1)
else:
tname = name
schema = 'public'
# column typarray not available before PG 8.3 # get the type oid and attributes
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" curs.execute("""\
# get the type oid and attributes
curs.execute("""\
SELECT t.oid, %s, attname, atttypid SELECT t.oid, %s, attname, atttypid
FROM pg_type t FROM pg_type t
JOIN pg_namespace ns ON typnamespace = ns.oid JOIN pg_namespace ns ON typnamespace = ns.oid
@ -1118,12 +1119,12 @@ WHERE typname = %%s AND nspname = %%s
ORDER BY attnum; ORDER BY attnum;
""" % typarray, (tname, schema)) """ % typarray, (tname, schema))
recs = curs.fetchall() recs = curs.fetchall()
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION if (conn_status != _ext.STATUS_IN_TRANSACTION
and not conn.autocommit): and not conn.autocommit):
conn.rollback() conn.rollback()
if not recs: if not recs:
raise psycopg2.ProgrammingError( raise psycopg2.ProgrammingError(

View File

@ -36,7 +36,7 @@ from psycopg2._json import _get_json_oids
from psycopg2.extras import ( from psycopg2.extras import (
CompositeCaster, DateRange, DateTimeRange, DateTimeTZRange, HstoreAdapter, CompositeCaster, DateRange, DateTimeRange, DateTimeTZRange, HstoreAdapter,
Inet, Json, NumericRange, Range, RealDictConnection, Inet, Json, NumericRange, Range, RealDictConnection,
register_composite, register_hstore, register_range, register_composite, register_hstore, register_range, _solve_conn_curs
) )
from psycopg2.tz import FixedOffsetTimezone from psycopg2.tz import FixedOffsetTimezone
@ -1632,6 +1632,32 @@ class RangeCasterTestCase(ConnectingTestCase):
cur.execute("rollback to savepoint x;") cur.execute("rollback to savepoint x;")
class TestSolveConnCurs(ConnectingTestCase):
def test_pass_connection(self):
with _solve_conn_curs(self.conn) as (conn, curs):
self.assertIsInstance(conn, psycopg2.extensions.connection)
self.assertIsInstance(curs, psycopg2.extensions.cursor)
self.assertIs(conn, self.conn)
self.assertFalse(conn.closed)
self.assertFalse(curs.closed)
self.assertFalse(conn.closed)
self.assertTrue(curs.closed)
def test_pass_cursor(self):
with self.conn.cursor() as cursor:
with _solve_conn_curs(cursor) as (conn, curs):
self.assertIsInstance(conn, psycopg2.extensions.connection)
self.assertIsInstance(curs, psycopg2.extensions.cursor)
self.assertIs(conn, self.conn)
self.assertIsNot(curs, cursor)
self.assertFalse(conn.closed)
self.assertFalse(curs.closed)
self.assertFalse(conn.closed)
self.assertTrue(curs.closed)
self.assertFalse(conn.closed)
self.assertTrue(curs.closed)
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)