mirror of
https://github.com/psycopg/psycopg2.git
synced 2024-11-21 16:36:34 +03:00
fix: look up for range types defined in schemas in the search path
This commit is contained in:
parent
9535462ce9
commit
ac25d3bdc0
3
NEWS
3
NEWS
|
@ -1,7 +1,8 @@
|
|||
What's new in psycopg 2.9.4 (unreleased)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Fix `register_composite()` with customized search_path (:ticket:`#1487`).
|
||||
- Fix `register_composite()`, `register_range()` with customized search_path
|
||||
(:ticket:`#1487`).
|
||||
- Handle correctly composite types with names or in schemas requiring escape.
|
||||
|
||||
|
||||
|
|
|
@ -363,33 +363,54 @@ class RangeCaster:
|
|||
schema = 'public'
|
||||
|
||||
# get the type oid and attributes
|
||||
try:
|
||||
curs.execute("""\
|
||||
select rngtypid, rngsubtype,
|
||||
(select typarray from pg_type where oid = rngtypid)
|
||||
curs.execute("""\
|
||||
select rngtypid, rngsubtype, typarray
|
||||
from pg_range r
|
||||
join pg_type t on t.oid = rngtypid
|
||||
join pg_namespace ns on ns.oid = typnamespace
|
||||
where typname = %s and ns.nspname = %s;
|
||||
""", (tname, schema))
|
||||
rec = curs.fetchone()
|
||||
|
||||
except ProgrammingError:
|
||||
if not conn.autocommit:
|
||||
conn.rollback()
|
||||
raise
|
||||
else:
|
||||
rec = curs.fetchone()
|
||||
if not rec:
|
||||
# The above algorithm doesn't work for customized seach_path
|
||||
# (#1487) The implementation below works better, but, to guarantee
|
||||
# backwards compatibility, use it only if the original one failed.
|
||||
try:
|
||||
savepoint = False
|
||||
# Because we executed statements earlier, we are either INTRANS
|
||||
# or we are IDLE only if the transaction is autocommit, in
|
||||
# which case we don't need the savepoint anyway.
|
||||
if conn.status == STATUS_IN_TRANSACTION:
|
||||
curs.execute("SAVEPOINT register_type")
|
||||
savepoint = True
|
||||
|
||||
# revert the status of the connection as before the command
|
||||
if (conn_status != STATUS_IN_TRANSACTION
|
||||
and not conn.autocommit):
|
||||
conn.rollback()
|
||||
curs.execute("""\
|
||||
SELECT rngtypid, rngsubtype, typarray, typname, nspname
|
||||
from pg_range r
|
||||
join pg_type t on t.oid = rngtypid
|
||||
join pg_namespace ns on ns.oid = typnamespace
|
||||
WHERE t.oid = %s::regtype
|
||||
""", (name, ))
|
||||
except ProgrammingError:
|
||||
pass
|
||||
else:
|
||||
rec = curs.fetchone()
|
||||
if rec:
|
||||
tname, schema = rec[3:]
|
||||
finally:
|
||||
if savepoint:
|
||||
curs.execute("ROLLBACK TO SAVEPOINT register_type")
|
||||
|
||||
# revert the status of the connection as before the command
|
||||
if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
|
||||
conn.rollback()
|
||||
|
||||
if not rec:
|
||||
raise ProgrammingError(
|
||||
f"PostgreSQL type '{name}' not found")
|
||||
f"PostgreSQL range '{name}' not found")
|
||||
|
||||
type, subtype, array = rec
|
||||
type, subtype, array = rec[:3]
|
||||
|
||||
return RangeCaster(name, pyrange,
|
||||
oid=type, subtype_oid=subtype, array_oid=array)
|
||||
|
|
|
@ -1615,6 +1615,18 @@ class RangeCasterTestCase(ConnectingTestCase):
|
|||
cur = self.conn.cursor()
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
register_range, 'nosuchrange', 'FailRange', cur)
|
||||
self.assertEqual(self.conn.status, ext.STATUS_READY)
|
||||
|
||||
cur.execute("select 1")
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
register_range, 'nosuchrange', 'FailRange', cur)
|
||||
|
||||
self.assertEqual(self.conn.status, ext.STATUS_IN_TRANSACTION)
|
||||
|
||||
self.conn.rollback()
|
||||
self.conn.autocommit = True
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
register_range, 'nosuchrange', 'FailRange', cur)
|
||||
|
||||
@restore_types
|
||||
def test_schema_range(self):
|
||||
|
@ -1629,7 +1641,7 @@ class RangeCasterTestCase(ConnectingTestCase):
|
|||
register_range('r1', 'r1', cur)
|
||||
ra2 = register_range('r2', 'r2', cur)
|
||||
rars2 = register_range('rs.r2', 'r2', cur)
|
||||
register_range('rs.r3', 'r3', cur)
|
||||
rars3 = register_range('rs.r3', 'r3', cur)
|
||||
|
||||
self.assertNotEqual(
|
||||
ra2.typecaster.values[0],
|
||||
|
@ -1643,6 +1655,11 @@ class RangeCasterTestCase(ConnectingTestCase):
|
|||
register_range, 'rs.r1', 'FailRange', cur)
|
||||
cur.execute("rollback to savepoint x;")
|
||||
|
||||
cur2 = self.conn.cursor()
|
||||
cur2.execute("set local search_path to rs,public")
|
||||
ra3 = register_range('r3', 'r3', cur2)
|
||||
self.assertEqual(ra3.typecaster.values[0], rars3.typecaster.values[0])
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
|
Loading…
Reference in New Issue
Block a user