mirror of
https://github.com/psycopg/psycopg2.git
synced 2025-07-29 09:29:46 +03:00
Merge e9ae67ff07
into 73969ba3e7
This commit is contained in:
commit
38b0e253d1
27
lib/_json.py
27
lib/_json.py
|
@ -180,23 +180,22 @@ def _get_json_oids(conn_or_curs, name='json'):
|
||||||
from psycopg2.extensions import STATUS_IN_TRANSACTION
|
from psycopg2.extensions import STATUS_IN_TRANSACTION
|
||||||
from psycopg2.extras import _solve_conn_curs
|
from psycopg2.extras import _solve_conn_curs
|
||||||
|
|
||||||
conn, curs = _solve_conn_curs(conn_or_curs)
|
with _solve_conn_curs(conn_or_curs) as (conn, curs):
|
||||||
|
# Store the transaction status of the connection to revert it after use
|
||||||
|
conn_status = conn.status
|
||||||
|
|
||||||
# Store the transaction status of the connection to revert it after use
|
# column typarray not available before PG 8.3
|
||||||
conn_status = conn.status
|
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
|
||||||
|
|
||||||
# column typarray not available before PG 8.3
|
# get the oid for the hstore
|
||||||
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
|
curs.execute(
|
||||||
|
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;"
|
||||||
|
% typarray, (name,))
|
||||||
|
r = curs.fetchone()
|
||||||
|
|
||||||
# get the oid for the hstore
|
# revert the status of the connection as before the command
|
||||||
curs.execute(
|
if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
|
||||||
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;"
|
conn.rollback()
|
||||||
% typarray, (name,))
|
|
||||||
r = curs.fetchone()
|
|
||||||
|
|
||||||
# revert the status of the connection as before the command
|
|
||||||
if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
|
|
||||||
conn.rollback()
|
|
||||||
|
|
||||||
if not r:
|
if not r:
|
||||||
raise conn.ProgrammingError("%s data type not found" % name)
|
raise conn.ProgrammingError("%s data type not found" % name)
|
||||||
|
|
|
@ -351,25 +351,24 @@ class RangeCaster(object):
|
||||||
"""
|
"""
|
||||||
from psycopg2.extensions import STATUS_IN_TRANSACTION
|
from psycopg2.extensions import STATUS_IN_TRANSACTION
|
||||||
from psycopg2.extras import _solve_conn_curs
|
from psycopg2.extras import _solve_conn_curs
|
||||||
conn, curs = _solve_conn_curs(conn_or_curs)
|
with _solve_conn_curs(conn_or_curs) as (conn, curs):
|
||||||
|
if conn.info.server_version < 90200:
|
||||||
|
raise ProgrammingError("range types not available in version %s"
|
||||||
|
% conn.info.server_version)
|
||||||
|
|
||||||
if conn.info.server_version < 90200:
|
# Store the transaction status of the connection to revert it after use
|
||||||
raise ProgrammingError("range types not available in version %s"
|
conn_status = conn.status
|
||||||
% conn.info.server_version)
|
|
||||||
|
|
||||||
# Store the transaction status of the connection to revert it after use
|
# Use the correct schema
|
||||||
conn_status = conn.status
|
if '.' in name:
|
||||||
|
schema, tname = name.split('.', 1)
|
||||||
|
else:
|
||||||
|
tname = name
|
||||||
|
schema = 'public'
|
||||||
|
|
||||||
# Use the correct schema
|
# get the type oid and attributes
|
||||||
if '.' in name:
|
try:
|
||||||
schema, tname = name.split('.', 1)
|
curs.execute("""\
|
||||||
else:
|
|
||||||
tname = name
|
|
||||||
schema = 'public'
|
|
||||||
|
|
||||||
# get the type oid and attributes
|
|
||||||
try:
|
|
||||||
curs.execute("""\
|
|
||||||
select rngtypid, rngsubtype,
|
select rngtypid, rngsubtype,
|
||||||
(select typarray from pg_type where oid = rngtypid)
|
(select typarray from pg_type where oid = rngtypid)
|
||||||
from pg_range r
|
from pg_range r
|
||||||
|
@ -378,17 +377,17 @@ join pg_namespace ns on ns.oid = typnamespace
|
||||||
where typname = %s and ns.nspname = %s;
|
where typname = %s and ns.nspname = %s;
|
||||||
""", (tname, schema))
|
""", (tname, schema))
|
||||||
|
|
||||||
except ProgrammingError:
|
except ProgrammingError:
|
||||||
if not conn.autocommit:
|
if not conn.autocommit:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
rec = curs.fetchone()
|
rec = curs.fetchone()
|
||||||
|
|
||||||
# revert the status of the connection as before the command
|
# revert the status of the connection as before the command
|
||||||
if (conn_status != STATUS_IN_TRANSACTION
|
if (conn_status != STATUS_IN_TRANSACTION
|
||||||
and not conn.autocommit):
|
and not conn.autocommit):
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
|
|
||||||
if not rec:
|
if not rec:
|
||||||
raise ProgrammingError(
|
raise ProgrammingError(
|
||||||
|
|
|
@ -26,6 +26,7 @@ and classes until a better place in the distribution is found.
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||||
# License for more details.
|
# License for more details.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import os as _os
|
import os as _os
|
||||||
import time as _time
|
import time as _time
|
||||||
import re as _re
|
import re as _re
|
||||||
|
@ -790,6 +791,7 @@ def wait_select(conn):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
def _solve_conn_curs(conn_or_curs):
|
def _solve_conn_curs(conn_or_curs):
|
||||||
"""Return the connection and a DBAPI cursor from a connection or cursor."""
|
"""Return the connection and a DBAPI cursor from a connection or cursor."""
|
||||||
if conn_or_curs is None:
|
if conn_or_curs is None:
|
||||||
|
@ -802,7 +804,8 @@ def _solve_conn_curs(conn_or_curs):
|
||||||
conn = conn_or_curs
|
conn = conn_or_curs
|
||||||
curs = conn.cursor(cursor_factory=_cursor)
|
curs = conn.cursor(cursor_factory=_cursor)
|
||||||
|
|
||||||
return conn, curs
|
with curs:
|
||||||
|
yield conn, curs
|
||||||
|
|
||||||
|
|
||||||
class HstoreAdapter(object):
|
class HstoreAdapter(object):
|
||||||
|
@ -913,31 +916,30 @@ class HstoreAdapter(object):
|
||||||
def get_oids(self, conn_or_curs):
|
def get_oids(self, conn_or_curs):
|
||||||
"""Return the lists of OID of the hstore and hstore[] types.
|
"""Return the lists of OID of the hstore and hstore[] types.
|
||||||
"""
|
"""
|
||||||
conn, curs = _solve_conn_curs(conn_or_curs)
|
with _solve_conn_curs(conn_or_curs) as (conn, curs):
|
||||||
|
# Store the transaction status of the connection to revert it after use
|
||||||
|
conn_status = conn.status
|
||||||
|
|
||||||
# Store the transaction status of the connection to revert it after use
|
# column typarray not available before PG 8.3
|
||||||
conn_status = conn.status
|
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
|
||||||
|
|
||||||
# column typarray not available before PG 8.3
|
rv0, rv1 = [], []
|
||||||
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
|
|
||||||
|
|
||||||
rv0, rv1 = [], []
|
# get the oid for the hstore
|
||||||
|
curs.execute("""\
|
||||||
# get the oid for the hstore
|
|
||||||
curs.execute("""\
|
|
||||||
SELECT t.oid, %s
|
SELECT t.oid, %s
|
||||||
FROM pg_type t JOIN pg_namespace ns
|
FROM pg_type t JOIN pg_namespace ns
|
||||||
ON typnamespace = ns.oid
|
ON typnamespace = ns.oid
|
||||||
WHERE typname = 'hstore';
|
WHERE typname = 'hstore';
|
||||||
""" % typarray)
|
""" % typarray)
|
||||||
for oids in curs:
|
for oids in curs:
|
||||||
rv0.append(oids[0])
|
rv0.append(oids[0])
|
||||||
rv1.append(oids[1])
|
rv1.append(oids[1])
|
||||||
|
|
||||||
# revert the status of the connection as before the command
|
# revert the status of the connection as before the command
|
||||||
if (conn_status != _ext.STATUS_IN_TRANSACTION
|
if (conn_status != _ext.STATUS_IN_TRANSACTION
|
||||||
and not conn.autocommit):
|
and not conn.autocommit):
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
|
|
||||||
return tuple(rv0), tuple(rv1)
|
return tuple(rv0), tuple(rv1)
|
||||||
|
|
||||||
|
@ -1092,23 +1094,22 @@ class CompositeCaster(object):
|
||||||
|
|
||||||
Raise `ProgrammingError` if the type is not found.
|
Raise `ProgrammingError` if the type is not found.
|
||||||
"""
|
"""
|
||||||
conn, curs = _solve_conn_curs(conn_or_curs)
|
with _solve_conn_curs(conn_or_curs) as (conn, curs):
|
||||||
|
# Store the transaction status of the connection to revert it after use
|
||||||
|
conn_status = conn.status
|
||||||
|
|
||||||
# Store the transaction status of the connection to revert it after use
|
# Use the correct schema
|
||||||
conn_status = conn.status
|
if '.' in name:
|
||||||
|
schema, tname = name.split('.', 1)
|
||||||
|
else:
|
||||||
|
tname = name
|
||||||
|
schema = 'public'
|
||||||
|
|
||||||
# Use the correct schema
|
# column typarray not available before PG 8.3
|
||||||
if '.' in name:
|
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
|
||||||
schema, tname = name.split('.', 1)
|
|
||||||
else:
|
|
||||||
tname = name
|
|
||||||
schema = 'public'
|
|
||||||
|
|
||||||
# column typarray not available before PG 8.3
|
# get the type oid and attributes
|
||||||
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
|
curs.execute("""\
|
||||||
|
|
||||||
# get the type oid and attributes
|
|
||||||
curs.execute("""\
|
|
||||||
SELECT t.oid, %s, attname, atttypid
|
SELECT t.oid, %s, attname, atttypid
|
||||||
FROM pg_type t
|
FROM pg_type t
|
||||||
JOIN pg_namespace ns ON typnamespace = ns.oid
|
JOIN pg_namespace ns ON typnamespace = ns.oid
|
||||||
|
@ -1118,12 +1119,12 @@ WHERE typname = %%s AND nspname = %%s
|
||||||
ORDER BY attnum;
|
ORDER BY attnum;
|
||||||
""" % typarray, (tname, schema))
|
""" % typarray, (tname, schema))
|
||||||
|
|
||||||
recs = curs.fetchall()
|
recs = curs.fetchall()
|
||||||
|
|
||||||
# revert the status of the connection as before the command
|
# revert the status of the connection as before the command
|
||||||
if (conn_status != _ext.STATUS_IN_TRANSACTION
|
if (conn_status != _ext.STATUS_IN_TRANSACTION
|
||||||
and not conn.autocommit):
|
and not conn.autocommit):
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
|
|
||||||
if not recs:
|
if not recs:
|
||||||
raise psycopg2.ProgrammingError(
|
raise psycopg2.ProgrammingError(
|
||||||
|
|
|
@ -36,7 +36,7 @@ from psycopg2._json import _get_json_oids
|
||||||
from psycopg2.extras import (
|
from psycopg2.extras import (
|
||||||
CompositeCaster, DateRange, DateTimeRange, DateTimeTZRange, HstoreAdapter,
|
CompositeCaster, DateRange, DateTimeRange, DateTimeTZRange, HstoreAdapter,
|
||||||
Inet, Json, NumericRange, Range, RealDictConnection,
|
Inet, Json, NumericRange, Range, RealDictConnection,
|
||||||
register_composite, register_hstore, register_range,
|
register_composite, register_hstore, register_range, _solve_conn_curs
|
||||||
)
|
)
|
||||||
from psycopg2.tz import FixedOffsetTimezone
|
from psycopg2.tz import FixedOffsetTimezone
|
||||||
|
|
||||||
|
@ -1632,6 +1632,32 @@ class RangeCasterTestCase(ConnectingTestCase):
|
||||||
cur.execute("rollback to savepoint x;")
|
cur.execute("rollback to savepoint x;")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSolveConnCurs(ConnectingTestCase):
|
||||||
|
def test_pass_connection(self):
|
||||||
|
with _solve_conn_curs(self.conn) as (conn, curs):
|
||||||
|
self.assertIsInstance(conn, psycopg2.extensions.connection)
|
||||||
|
self.assertIsInstance(curs, psycopg2.extensions.cursor)
|
||||||
|
self.assertIs(conn, self.conn)
|
||||||
|
self.assertFalse(conn.closed)
|
||||||
|
self.assertFalse(curs.closed)
|
||||||
|
self.assertFalse(conn.closed)
|
||||||
|
self.assertTrue(curs.closed)
|
||||||
|
|
||||||
|
def test_pass_cursor(self):
|
||||||
|
with self.conn.cursor() as cursor:
|
||||||
|
with _solve_conn_curs(cursor) as (conn, curs):
|
||||||
|
self.assertIsInstance(conn, psycopg2.extensions.connection)
|
||||||
|
self.assertIsInstance(curs, psycopg2.extensions.cursor)
|
||||||
|
self.assertIs(conn, self.conn)
|
||||||
|
self.assertIsNot(curs, cursor)
|
||||||
|
self.assertFalse(conn.closed)
|
||||||
|
self.assertFalse(curs.closed)
|
||||||
|
self.assertFalse(conn.closed)
|
||||||
|
self.assertTrue(curs.closed)
|
||||||
|
self.assertFalse(conn.closed)
|
||||||
|
self.assertTrue(curs.closed)
|
||||||
|
|
||||||
|
|
||||||
def test_suite():
|
def test_suite():
|
||||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user