This commit is contained in:
Jon Dufresne 2021-04-22 18:28:47 -04:00 committed by GitHub
commit 38b0e253d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 102 additions and 77 deletions

View File

@ -180,23 +180,22 @@ def _get_json_oids(conn_or_curs, name='json'):
from psycopg2.extensions import STATUS_IN_TRANSACTION
from psycopg2.extras import _solve_conn_curs
conn, curs = _solve_conn_curs(conn_or_curs)
with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# column typarray not available before PG 8.3
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
# column typarray not available before PG 8.3
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
# get the oid for the hstore
curs.execute(
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;"
% typarray, (name,))
r = curs.fetchone()
# get the oid for the hstore
curs.execute(
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;"
% typarray, (name,))
r = curs.fetchone()
# revert the status of the connection as before the command
if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
conn.rollback()
# revert the status of the connection as before the command
if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
conn.rollback()
if not r:
raise conn.ProgrammingError("%s data type not found" % name)

View File

@ -351,25 +351,24 @@ class RangeCaster(object):
"""
from psycopg2.extensions import STATUS_IN_TRANSACTION
from psycopg2.extras import _solve_conn_curs
conn, curs = _solve_conn_curs(conn_or_curs)
with _solve_conn_curs(conn_or_curs) as (conn, curs):
if conn.info.server_version < 90200:
raise ProgrammingError("range types not available in version %s"
% conn.info.server_version)
if conn.info.server_version < 90200:
raise ProgrammingError("range types not available in version %s"
% conn.info.server_version)
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# Use the correct schema
if '.' in name:
schema, tname = name.split('.', 1)
else:
tname = name
schema = 'public'
# Use the correct schema
if '.' in name:
schema, tname = name.split('.', 1)
else:
tname = name
schema = 'public'
# get the type oid and attributes
try:
curs.execute("""\
# get the type oid and attributes
try:
curs.execute("""\
select rngtypid, rngsubtype,
(select typarray from pg_type where oid = rngtypid)
from pg_range r
@ -378,17 +377,17 @@ join pg_namespace ns on ns.oid = typnamespace
where typname = %s and ns.nspname = %s;
""", (tname, schema))
except ProgrammingError:
if not conn.autocommit:
conn.rollback()
raise
else:
rec = curs.fetchone()
except ProgrammingError:
if not conn.autocommit:
conn.rollback()
raise
else:
rec = curs.fetchone()
# revert the status of the connection as before the command
if (conn_status != STATUS_IN_TRANSACTION
and not conn.autocommit):
conn.rollback()
# revert the status of the connection as before the command
if (conn_status != STATUS_IN_TRANSACTION
and not conn.autocommit):
conn.rollback()
if not rec:
raise ProgrammingError(

View File

@ -26,6 +26,7 @@ and classes until a better place in the distribution is found.
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import contextlib
import os as _os
import time as _time
import re as _re
@ -790,6 +791,7 @@ def wait_select(conn):
continue
@contextlib.contextmanager
def _solve_conn_curs(conn_or_curs):
"""Return the connection and a DBAPI cursor from a connection or cursor."""
if conn_or_curs is None:
@ -802,7 +804,8 @@ def _solve_conn_curs(conn_or_curs):
conn = conn_or_curs
curs = conn.cursor(cursor_factory=_cursor)
return conn, curs
with curs:
yield conn, curs
class HstoreAdapter(object):
@ -913,31 +916,30 @@ class HstoreAdapter(object):
def get_oids(self, conn_or_curs):
"""Return the lists of OID of the hstore and hstore[] types.
"""
conn, curs = _solve_conn_curs(conn_or_curs)
with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# column typarray not available before PG 8.3
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
# column typarray not available before PG 8.3
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
rv0, rv1 = [], []
rv0, rv1 = [], []
# get the oid for the hstore
curs.execute("""\
# get the oid for the hstore
curs.execute("""\
SELECT t.oid, %s
FROM pg_type t JOIN pg_namespace ns
ON typnamespace = ns.oid
WHERE typname = 'hstore';
""" % typarray)
for oids in curs:
rv0.append(oids[0])
rv1.append(oids[1])
for oids in curs:
rv0.append(oids[0])
rv1.append(oids[1])
# revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION
and not conn.autocommit):
conn.rollback()
# revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION
and not conn.autocommit):
conn.rollback()
return tuple(rv0), tuple(rv1)
@ -1092,23 +1094,22 @@ class CompositeCaster(object):
Raise `ProgrammingError` if the type is not found.
"""
conn, curs = _solve_conn_curs(conn_or_curs)
with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# Use the correct schema
if '.' in name:
schema, tname = name.split('.', 1)
else:
tname = name
schema = 'public'
# Use the correct schema
if '.' in name:
schema, tname = name.split('.', 1)
else:
tname = name
schema = 'public'
# column typarray not available before PG 8.3
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
# column typarray not available before PG 8.3
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
# get the type oid and attributes
curs.execute("""\
# get the type oid and attributes
curs.execute("""\
SELECT t.oid, %s, attname, atttypid
FROM pg_type t
JOIN pg_namespace ns ON typnamespace = ns.oid
@ -1118,12 +1119,12 @@ WHERE typname = %%s AND nspname = %%s
ORDER BY attnum;
""" % typarray, (tname, schema))
recs = curs.fetchall()
recs = curs.fetchall()
# revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION
and not conn.autocommit):
conn.rollback()
# revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION
and not conn.autocommit):
conn.rollback()
if not recs:
raise psycopg2.ProgrammingError(

View File

@ -36,7 +36,7 @@ from psycopg2._json import _get_json_oids
from psycopg2.extras import (
CompositeCaster, DateRange, DateTimeRange, DateTimeTZRange, HstoreAdapter,
Inet, Json, NumericRange, Range, RealDictConnection,
register_composite, register_hstore, register_range,
register_composite, register_hstore, register_range, _solve_conn_curs
)
from psycopg2.tz import FixedOffsetTimezone
@ -1632,6 +1632,32 @@ class RangeCasterTestCase(ConnectingTestCase):
cur.execute("rollback to savepoint x;")
class TestSolveConnCurs(ConnectingTestCase):
def test_pass_connection(self):
with _solve_conn_curs(self.conn) as (conn, curs):
self.assertIsInstance(conn, psycopg2.extensions.connection)
self.assertIsInstance(curs, psycopg2.extensions.cursor)
self.assertIs(conn, self.conn)
self.assertFalse(conn.closed)
self.assertFalse(curs.closed)
self.assertFalse(conn.closed)
self.assertTrue(curs.closed)
def test_pass_cursor(self):
with self.conn.cursor() as cursor:
with _solve_conn_curs(cursor) as (conn, curs):
self.assertIsInstance(conn, psycopg2.extensions.connection)
self.assertIsInstance(curs, psycopg2.extensions.cursor)
self.assertIs(conn, self.conn)
self.assertIsNot(curs, cursor)
self.assertFalse(conn.closed)
self.assertFalse(curs.closed)
self.assertFalse(conn.closed)
self.assertTrue(curs.closed)
self.assertFalse(conn.closed)
self.assertTrue(curs.closed)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)