fix: look up for range types defined in schemas in the search path

This commit is contained in:
Daniele Varrazzo 2022-10-06 01:58:27 +01:00
parent 9535462ce9
commit ac25d3bdc0
3 changed files with 57 additions and 18 deletions

3
NEWS
View File

@ -1,7 +1,8 @@
What's new in psycopg 2.9.4 (unreleased) What's new in psycopg 2.9.4 (unreleased)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- Fix `register_composite()` with customized search_path (:ticket:`#1487`). - Fix `register_composite()`, `register_range()` with customized search_path
(:ticket:`#1487`).
- Handle correctly composite types with names or in schemas requiring escape. - Handle correctly composite types with names or in schemas requiring escape.

View File

@ -363,33 +363,54 @@ class RangeCaster:
schema = 'public' schema = 'public'
# get the type oid and attributes # get the type oid and attributes
try: curs.execute("""\
curs.execute("""\ select rngtypid, rngsubtype, typarray
select rngtypid, rngsubtype,
(select typarray from pg_type where oid = rngtypid)
from pg_range r from pg_range r
join pg_type t on t.oid = rngtypid join pg_type t on t.oid = rngtypid
join pg_namespace ns on ns.oid = typnamespace join pg_namespace ns on ns.oid = typnamespace
where typname = %s and ns.nspname = %s; where typname = %s and ns.nspname = %s;
""", (tname, schema)) """, (tname, schema))
rec = curs.fetchone()
except ProgrammingError: if not rec:
if not conn.autocommit: # The above algorithm doesn't work for customized seach_path
conn.rollback() # (#1487) The implementation below works better, but, to guarantee
raise # backwards compatibility, use it only if the original one failed.
else: try:
rec = curs.fetchone() savepoint = False
# Because we executed statements earlier, we are either INTRANS
# or we are IDLE only if the transaction is autocommit, in
# which case we don't need the savepoint anyway.
if conn.status == STATUS_IN_TRANSACTION:
curs.execute("SAVEPOINT register_type")
savepoint = True
# revert the status of the connection as before the command curs.execute("""\
if (conn_status != STATUS_IN_TRANSACTION SELECT rngtypid, rngsubtype, typarray, typname, nspname
and not conn.autocommit): from pg_range r
conn.rollback() join pg_type t on t.oid = rngtypid
join pg_namespace ns on ns.oid = typnamespace
WHERE t.oid = %s::regtype
""", (name, ))
except ProgrammingError:
pass
else:
rec = curs.fetchone()
if rec:
tname, schema = rec[3:]
finally:
if savepoint:
curs.execute("ROLLBACK TO SAVEPOINT register_type")
# revert the status of the connection as before the command
if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
conn.rollback()
if not rec: if not rec:
raise ProgrammingError( raise ProgrammingError(
f"PostgreSQL type '{name}' not found") f"PostgreSQL range '{name}' not found")
type, subtype, array = rec type, subtype, array = rec[:3]
return RangeCaster(name, pyrange, return RangeCaster(name, pyrange,
oid=type, subtype_oid=subtype, array_oid=array) oid=type, subtype_oid=subtype, array_oid=array)

View File

@ -1615,6 +1615,18 @@ class RangeCasterTestCase(ConnectingTestCase):
cur = self.conn.cursor() cur = self.conn.cursor()
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
register_range, 'nosuchrange', 'FailRange', cur) register_range, 'nosuchrange', 'FailRange', cur)
self.assertEqual(self.conn.status, ext.STATUS_READY)
cur.execute("select 1")
self.assertRaises(psycopg2.ProgrammingError,
register_range, 'nosuchrange', 'FailRange', cur)
self.assertEqual(self.conn.status, ext.STATUS_IN_TRANSACTION)
self.conn.rollback()
self.conn.autocommit = True
self.assertRaises(psycopg2.ProgrammingError,
register_range, 'nosuchrange', 'FailRange', cur)
@restore_types @restore_types
def test_schema_range(self): def test_schema_range(self):
@ -1629,7 +1641,7 @@ class RangeCasterTestCase(ConnectingTestCase):
register_range('r1', 'r1', cur) register_range('r1', 'r1', cur)
ra2 = register_range('r2', 'r2', cur) ra2 = register_range('r2', 'r2', cur)
rars2 = register_range('rs.r2', 'r2', cur) rars2 = register_range('rs.r2', 'r2', cur)
register_range('rs.r3', 'r3', cur) rars3 = register_range('rs.r3', 'r3', cur)
self.assertNotEqual( self.assertNotEqual(
ra2.typecaster.values[0], ra2.typecaster.values[0],
@ -1643,6 +1655,11 @@ class RangeCasterTestCase(ConnectingTestCase):
register_range, 'rs.r1', 'FailRange', cur) register_range, 'rs.r1', 'FailRange', cur)
cur.execute("rollback to savepoint x;") cur.execute("rollback to savepoint x;")
cur2 = self.conn.cursor()
cur2.execute("set local search_path to rs,public")
ra3 = register_range('r3', 'r3', cur2)
self.assertEqual(ra3.typecaster.values[0], rars3.typecaster.values[0])
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)