mirror of
https://github.com/psycopg/psycopg2.git
synced 2025-07-29 17:39:49 +03:00
Merge e9ae67ff07
into 73969ba3e7
This commit is contained in:
commit
38b0e253d1
|
@ -180,8 +180,7 @@ def _get_json_oids(conn_or_curs, name='json'):
|
||||||
from psycopg2.extensions import STATUS_IN_TRANSACTION
|
from psycopg2.extensions import STATUS_IN_TRANSACTION
|
||||||
from psycopg2.extras import _solve_conn_curs
|
from psycopg2.extras import _solve_conn_curs
|
||||||
|
|
||||||
conn, curs = _solve_conn_curs(conn_or_curs)
|
with _solve_conn_curs(conn_or_curs) as (conn, curs):
|
||||||
|
|
||||||
# Store the transaction status of the connection to revert it after use
|
# Store the transaction status of the connection to revert it after use
|
||||||
conn_status = conn.status
|
conn_status = conn.status
|
||||||
|
|
||||||
|
|
|
@ -351,8 +351,7 @@ class RangeCaster(object):
|
||||||
"""
|
"""
|
||||||
from psycopg2.extensions import STATUS_IN_TRANSACTION
|
from psycopg2.extensions import STATUS_IN_TRANSACTION
|
||||||
from psycopg2.extras import _solve_conn_curs
|
from psycopg2.extras import _solve_conn_curs
|
||||||
conn, curs = _solve_conn_curs(conn_or_curs)
|
with _solve_conn_curs(conn_or_curs) as (conn, curs):
|
||||||
|
|
||||||
if conn.info.server_version < 90200:
|
if conn.info.server_version < 90200:
|
||||||
raise ProgrammingError("range types not available in version %s"
|
raise ProgrammingError("range types not available in version %s"
|
||||||
% conn.info.server_version)
|
% conn.info.server_version)
|
||||||
|
|
|
@ -26,6 +26,7 @@ and classes until a better place in the distribution is found.
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||||
# License for more details.
|
# License for more details.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import os as _os
|
import os as _os
|
||||||
import time as _time
|
import time as _time
|
||||||
import re as _re
|
import re as _re
|
||||||
|
@ -790,6 +791,7 @@ def wait_select(conn):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
def _solve_conn_curs(conn_or_curs):
|
def _solve_conn_curs(conn_or_curs):
|
||||||
"""Return the connection and a DBAPI cursor from a connection or cursor."""
|
"""Return the connection and a DBAPI cursor from a connection or cursor."""
|
||||||
if conn_or_curs is None:
|
if conn_or_curs is None:
|
||||||
|
@ -802,7 +804,8 @@ def _solve_conn_curs(conn_or_curs):
|
||||||
conn = conn_or_curs
|
conn = conn_or_curs
|
||||||
curs = conn.cursor(cursor_factory=_cursor)
|
curs = conn.cursor(cursor_factory=_cursor)
|
||||||
|
|
||||||
return conn, curs
|
with curs:
|
||||||
|
yield conn, curs
|
||||||
|
|
||||||
|
|
||||||
class HstoreAdapter(object):
|
class HstoreAdapter(object):
|
||||||
|
@ -913,8 +916,7 @@ class HstoreAdapter(object):
|
||||||
def get_oids(self, conn_or_curs):
|
def get_oids(self, conn_or_curs):
|
||||||
"""Return the lists of OID of the hstore and hstore[] types.
|
"""Return the lists of OID of the hstore and hstore[] types.
|
||||||
"""
|
"""
|
||||||
conn, curs = _solve_conn_curs(conn_or_curs)
|
with _solve_conn_curs(conn_or_curs) as (conn, curs):
|
||||||
|
|
||||||
# Store the transaction status of the connection to revert it after use
|
# Store the transaction status of the connection to revert it after use
|
||||||
conn_status = conn.status
|
conn_status = conn.status
|
||||||
|
|
||||||
|
@ -1092,8 +1094,7 @@ class CompositeCaster(object):
|
||||||
|
|
||||||
Raise `ProgrammingError` if the type is not found.
|
Raise `ProgrammingError` if the type is not found.
|
||||||
"""
|
"""
|
||||||
conn, curs = _solve_conn_curs(conn_or_curs)
|
with _solve_conn_curs(conn_or_curs) as (conn, curs):
|
||||||
|
|
||||||
# Store the transaction status of the connection to revert it after use
|
# Store the transaction status of the connection to revert it after use
|
||||||
conn_status = conn.status
|
conn_status = conn.status
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ from psycopg2._json import _get_json_oids
|
||||||
from psycopg2.extras import (
|
from psycopg2.extras import (
|
||||||
CompositeCaster, DateRange, DateTimeRange, DateTimeTZRange, HstoreAdapter,
|
CompositeCaster, DateRange, DateTimeRange, DateTimeTZRange, HstoreAdapter,
|
||||||
Inet, Json, NumericRange, Range, RealDictConnection,
|
Inet, Json, NumericRange, Range, RealDictConnection,
|
||||||
register_composite, register_hstore, register_range,
|
register_composite, register_hstore, register_range, _solve_conn_curs
|
||||||
)
|
)
|
||||||
from psycopg2.tz import FixedOffsetTimezone
|
from psycopg2.tz import FixedOffsetTimezone
|
||||||
|
|
||||||
|
@ -1632,6 +1632,32 @@ class RangeCasterTestCase(ConnectingTestCase):
|
||||||
cur.execute("rollback to savepoint x;")
|
cur.execute("rollback to savepoint x;")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSolveConnCurs(ConnectingTestCase):
|
||||||
|
def test_pass_connection(self):
|
||||||
|
with _solve_conn_curs(self.conn) as (conn, curs):
|
||||||
|
self.assertIsInstance(conn, psycopg2.extensions.connection)
|
||||||
|
self.assertIsInstance(curs, psycopg2.extensions.cursor)
|
||||||
|
self.assertIs(conn, self.conn)
|
||||||
|
self.assertFalse(conn.closed)
|
||||||
|
self.assertFalse(curs.closed)
|
||||||
|
self.assertFalse(conn.closed)
|
||||||
|
self.assertTrue(curs.closed)
|
||||||
|
|
||||||
|
def test_pass_cursor(self):
|
||||||
|
with self.conn.cursor() as cursor:
|
||||||
|
with _solve_conn_curs(cursor) as (conn, curs):
|
||||||
|
self.assertIsInstance(conn, psycopg2.extensions.connection)
|
||||||
|
self.assertIsInstance(curs, psycopg2.extensions.cursor)
|
||||||
|
self.assertIs(conn, self.conn)
|
||||||
|
self.assertIsNot(curs, cursor)
|
||||||
|
self.assertFalse(conn.closed)
|
||||||
|
self.assertFalse(curs.closed)
|
||||||
|
self.assertFalse(conn.closed)
|
||||||
|
self.assertTrue(curs.closed)
|
||||||
|
self.assertFalse(conn.closed)
|
||||||
|
self.assertTrue(curs.closed)
|
||||||
|
|
||||||
|
|
||||||
def test_suite():
|
def test_suite():
|
||||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user