mirror of
https://github.com/psycopg/psycopg2.git
synced 2025-07-28 17:10:05 +03:00
Improve resource ownership semantics of _solve_conn_curs()
A new cursor is always creates so always close it once finished.
This commit is contained in:
parent
9bcca1a7b0
commit
e9ae67ff07
27
lib/_json.py
27
lib/_json.py
|
@ -180,23 +180,22 @@ def _get_json_oids(conn_or_curs, name='json'):
|
|||
from psycopg2.extensions import STATUS_IN_TRANSACTION
|
||||
from psycopg2.extras import _solve_conn_curs
|
||||
|
||||
conn, curs = _solve_conn_curs(conn_or_curs)
|
||||
with _solve_conn_curs(conn_or_curs) as (conn, curs):
|
||||
# Store the transaction status of the connection to revert it after use
|
||||
conn_status = conn.status
|
||||
|
||||
# Store the transaction status of the connection to revert it after use
|
||||
conn_status = conn.status
|
||||
# column typarray not available before PG 8.3
|
||||
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
|
||||
|
||||
# column typarray not available before PG 8.3
|
||||
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
|
||||
# get the oid for the hstore
|
||||
curs.execute(
|
||||
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;"
|
||||
% typarray, (name,))
|
||||
r = curs.fetchone()
|
||||
|
||||
# get the oid for the hstore
|
||||
curs.execute(
|
||||
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;"
|
||||
% typarray, (name,))
|
||||
r = curs.fetchone()
|
||||
|
||||
# revert the status of the connection as before the command
|
||||
if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
|
||||
conn.rollback()
|
||||
# revert the status of the connection as before the command
|
||||
if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
|
||||
conn.rollback()
|
||||
|
||||
if not r:
|
||||
raise conn.ProgrammingError("%s data type not found" % name)
|
||||
|
|
|
@ -351,25 +351,24 @@ class RangeCaster(object):
|
|||
"""
|
||||
from psycopg2.extensions import STATUS_IN_TRANSACTION
|
||||
from psycopg2.extras import _solve_conn_curs
|
||||
conn, curs = _solve_conn_curs(conn_or_curs)
|
||||
with _solve_conn_curs(conn_or_curs) as (conn, curs):
|
||||
if conn.info.server_version < 90200:
|
||||
raise ProgrammingError("range types not available in version %s"
|
||||
% conn.info.server_version)
|
||||
|
||||
if conn.info.server_version < 90200:
|
||||
raise ProgrammingError("range types not available in version %s"
|
||||
% conn.info.server_version)
|
||||
# Store the transaction status of the connection to revert it after use
|
||||
conn_status = conn.status
|
||||
|
||||
# Store the transaction status of the connection to revert it after use
|
||||
conn_status = conn.status
|
||||
# Use the correct schema
|
||||
if '.' in name:
|
||||
schema, tname = name.split('.', 1)
|
||||
else:
|
||||
tname = name
|
||||
schema = 'public'
|
||||
|
||||
# Use the correct schema
|
||||
if '.' in name:
|
||||
schema, tname = name.split('.', 1)
|
||||
else:
|
||||
tname = name
|
||||
schema = 'public'
|
||||
|
||||
# get the type oid and attributes
|
||||
try:
|
||||
curs.execute("""\
|
||||
# get the type oid and attributes
|
||||
try:
|
||||
curs.execute("""\
|
||||
select rngtypid, rngsubtype,
|
||||
(select typarray from pg_type where oid = rngtypid)
|
||||
from pg_range r
|
||||
|
@ -378,17 +377,17 @@ join pg_namespace ns on ns.oid = typnamespace
|
|||
where typname = %s and ns.nspname = %s;
|
||||
""", (tname, schema))
|
||||
|
||||
except ProgrammingError:
|
||||
if not conn.autocommit:
|
||||
conn.rollback()
|
||||
raise
|
||||
else:
|
||||
rec = curs.fetchone()
|
||||
except ProgrammingError:
|
||||
if not conn.autocommit:
|
||||
conn.rollback()
|
||||
raise
|
||||
else:
|
||||
rec = curs.fetchone()
|
||||
|
||||
# revert the status of the connection as before the command
|
||||
if (conn_status != STATUS_IN_TRANSACTION
|
||||
and not conn.autocommit):
|
||||
conn.rollback()
|
||||
# revert the status of the connection as before the command
|
||||
if (conn_status != STATUS_IN_TRANSACTION
|
||||
and not conn.autocommit):
|
||||
conn.rollback()
|
||||
|
||||
if not rec:
|
||||
raise ProgrammingError(
|
||||
|
|
|
@ -26,6 +26,7 @@ and classes until a better place in the distribution is found.
|
|||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import contextlib
|
||||
import os as _os
|
||||
import time as _time
|
||||
import re as _re
|
||||
|
@ -781,6 +782,7 @@ def wait_select(conn):
|
|||
continue
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _solve_conn_curs(conn_or_curs):
|
||||
"""Return the connection and a DBAPI cursor from a connection or cursor."""
|
||||
if conn_or_curs is None:
|
||||
|
@ -793,7 +795,8 @@ def _solve_conn_curs(conn_or_curs):
|
|||
conn = conn_or_curs
|
||||
curs = conn.cursor(cursor_factory=_cursor)
|
||||
|
||||
return conn, curs
|
||||
with curs:
|
||||
yield conn, curs
|
||||
|
||||
|
||||
class HstoreAdapter(object):
|
||||
|
@ -904,31 +907,30 @@ class HstoreAdapter(object):
|
|||
def get_oids(self, conn_or_curs):
|
||||
"""Return the lists of OID of the hstore and hstore[] types.
|
||||
"""
|
||||
conn, curs = _solve_conn_curs(conn_or_curs)
|
||||
with _solve_conn_curs(conn_or_curs) as (conn, curs):
|
||||
# Store the transaction status of the connection to revert it after use
|
||||
conn_status = conn.status
|
||||
|
||||
# Store the transaction status of the connection to revert it after use
|
||||
conn_status = conn.status
|
||||
# column typarray not available before PG 8.3
|
||||
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
|
||||
|
||||
# column typarray not available before PG 8.3
|
||||
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
|
||||
rv0, rv1 = [], []
|
||||
|
||||
rv0, rv1 = [], []
|
||||
|
||||
# get the oid for the hstore
|
||||
curs.execute("""\
|
||||
# get the oid for the hstore
|
||||
curs.execute("""\
|
||||
SELECT t.oid, %s
|
||||
FROM pg_type t JOIN pg_namespace ns
|
||||
ON typnamespace = ns.oid
|
||||
WHERE typname = 'hstore';
|
||||
""" % typarray)
|
||||
for oids in curs:
|
||||
rv0.append(oids[0])
|
||||
rv1.append(oids[1])
|
||||
for oids in curs:
|
||||
rv0.append(oids[0])
|
||||
rv1.append(oids[1])
|
||||
|
||||
# revert the status of the connection as before the command
|
||||
if (conn_status != _ext.STATUS_IN_TRANSACTION
|
||||
and not conn.autocommit):
|
||||
conn.rollback()
|
||||
# revert the status of the connection as before the command
|
||||
if (conn_status != _ext.STATUS_IN_TRANSACTION
|
||||
and not conn.autocommit):
|
||||
conn.rollback()
|
||||
|
||||
return tuple(rv0), tuple(rv1)
|
||||
|
||||
|
@ -1083,23 +1085,22 @@ class CompositeCaster(object):
|
|||
|
||||
Raise `ProgrammingError` if the type is not found.
|
||||
"""
|
||||
conn, curs = _solve_conn_curs(conn_or_curs)
|
||||
with _solve_conn_curs(conn_or_curs) as (conn, curs):
|
||||
# Store the transaction status of the connection to revert it after use
|
||||
conn_status = conn.status
|
||||
|
||||
# Store the transaction status of the connection to revert it after use
|
||||
conn_status = conn.status
|
||||
# Use the correct schema
|
||||
if '.' in name:
|
||||
schema, tname = name.split('.', 1)
|
||||
else:
|
||||
tname = name
|
||||
schema = 'public'
|
||||
|
||||
# Use the correct schema
|
||||
if '.' in name:
|
||||
schema, tname = name.split('.', 1)
|
||||
else:
|
||||
tname = name
|
||||
schema = 'public'
|
||||
# column typarray not available before PG 8.3
|
||||
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
|
||||
|
||||
# column typarray not available before PG 8.3
|
||||
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
|
||||
|
||||
# get the type oid and attributes
|
||||
curs.execute("""\
|
||||
# get the type oid and attributes
|
||||
curs.execute("""\
|
||||
SELECT t.oid, %s, attname, atttypid
|
||||
FROM pg_type t
|
||||
JOIN pg_namespace ns ON typnamespace = ns.oid
|
||||
|
@ -1109,12 +1110,12 @@ WHERE typname = %%s AND nspname = %%s
|
|||
ORDER BY attnum;
|
||||
""" % typarray, (tname, schema))
|
||||
|
||||
recs = curs.fetchall()
|
||||
recs = curs.fetchall()
|
||||
|
||||
# revert the status of the connection as before the command
|
||||
if (conn_status != _ext.STATUS_IN_TRANSACTION
|
||||
and not conn.autocommit):
|
||||
conn.rollback()
|
||||
# revert the status of the connection as before the command
|
||||
if (conn_status != _ext.STATUS_IN_TRANSACTION
|
||||
and not conn.autocommit):
|
||||
conn.rollback()
|
||||
|
||||
if not recs:
|
||||
raise psycopg2.ProgrammingError(
|
||||
|
|
|
@ -36,7 +36,7 @@ from psycopg2._json import _get_json_oids
|
|||
from psycopg2.extras import (
|
||||
CompositeCaster, DateRange, DateTimeRange, DateTimeTZRange, HstoreAdapter,
|
||||
Inet, Json, NumericRange, Range, RealDictConnection,
|
||||
register_composite, register_hstore, register_range,
|
||||
register_composite, register_hstore, register_range, _solve_conn_curs
|
||||
)
|
||||
from psycopg2.tz import FixedOffsetTimezone
|
||||
|
||||
|
@ -1624,6 +1624,32 @@ class RangeCasterTestCase(ConnectingTestCase):
|
|||
cur.execute("rollback to savepoint x;")
|
||||
|
||||
|
||||
class TestSolveConnCurs(ConnectingTestCase):
|
||||
def test_pass_connection(self):
|
||||
with _solve_conn_curs(self.conn) as (conn, curs):
|
||||
self.assertIsInstance(conn, psycopg2.extensions.connection)
|
||||
self.assertIsInstance(curs, psycopg2.extensions.cursor)
|
||||
self.assertIs(conn, self.conn)
|
||||
self.assertFalse(conn.closed)
|
||||
self.assertFalse(curs.closed)
|
||||
self.assertFalse(conn.closed)
|
||||
self.assertTrue(curs.closed)
|
||||
|
||||
def test_pass_cursor(self):
|
||||
with self.conn.cursor() as cursor:
|
||||
with _solve_conn_curs(cursor) as (conn, curs):
|
||||
self.assertIsInstance(conn, psycopg2.extensions.connection)
|
||||
self.assertIsInstance(curs, psycopg2.extensions.cursor)
|
||||
self.assertIs(conn, self.conn)
|
||||
self.assertIsNot(curs, cursor)
|
||||
self.assertFalse(conn.closed)
|
||||
self.assertFalse(curs.closed)
|
||||
self.assertFalse(conn.closed)
|
||||
self.assertTrue(curs.closed)
|
||||
self.assertFalse(conn.closed)
|
||||
self.assertTrue(curs.closed)
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user