Merge branch 'fix-1487'

This commit is contained in:
Daniele Varrazzo 2022-10-06 02:09:06 +01:00
commit 68d786b610
4 changed files with 185 additions and 34 deletions

8
NEWS
View File

@ -1,3 +1,11 @@
What's new in psycopg 2.9.4 (unreleased)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- Fix `register_composite()`, `register_range()` with customized search_path
(:ticket:`#1487`).
- Handle correctly composite types with names or in schemas requiring escape.
Current release Current release
--------------- ---------------

View File

@ -363,33 +363,54 @@ class RangeCaster:
schema = 'public' schema = 'public'
# get the type oid and attributes # get the type oid and attributes
try: curs.execute("""\
curs.execute("""\ select rngtypid, rngsubtype, typarray
select rngtypid, rngsubtype,
(select typarray from pg_type where oid = rngtypid)
from pg_range r from pg_range r
join pg_type t on t.oid = rngtypid join pg_type t on t.oid = rngtypid
join pg_namespace ns on ns.oid = typnamespace join pg_namespace ns on ns.oid = typnamespace
where typname = %s and ns.nspname = %s; where typname = %s and ns.nspname = %s;
""", (tname, schema)) """, (tname, schema))
rec = curs.fetchone()
except ProgrammingError: if not rec:
if not conn.autocommit: # The above algorithm doesn't work for customized seach_path
conn.rollback() # (#1487) The implementation below works better, but, to guarantee
raise # backwards compatibility, use it only if the original one failed.
else: try:
rec = curs.fetchone() savepoint = False
# Because we executed statements earlier, we are either INTRANS
# or we are IDLE only if the transaction is autocommit, in
# which case we don't need the savepoint anyway.
if conn.status == STATUS_IN_TRANSACTION:
curs.execute("SAVEPOINT register_type")
savepoint = True
# revert the status of the connection as before the command curs.execute("""\
if (conn_status != STATUS_IN_TRANSACTION SELECT rngtypid, rngsubtype, typarray, typname, nspname
and not conn.autocommit): from pg_range r
conn.rollback() join pg_type t on t.oid = rngtypid
join pg_namespace ns on ns.oid = typnamespace
WHERE t.oid = %s::regtype
""", (name, ))
except ProgrammingError:
pass
else:
rec = curs.fetchone()
if rec:
tname, schema = rec[3:]
finally:
if savepoint:
curs.execute("ROLLBACK TO SAVEPOINT register_type")
# revert the status of the connection as before the command
if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
conn.rollback()
if not rec: if not rec:
raise ProgrammingError( raise ProgrammingError(
f"PostgreSQL type '{name}' not found") f"PostgreSQL range '{name}' not found")
type, subtype, array = rec type, subtype, array = rec[:3]
return RangeCaster(name, pyrange, return RangeCaster(name, pyrange,
oid=type, subtype_oid=subtype, array_oid=array) oid=type, subtype_oid=subtype, array_oid=array)

View File

@ -357,10 +357,6 @@ class NamedTupleCursor(_cursor):
except StopIteration: except StopIteration:
return return
# ascii except alnum and underscore
_re_clean = _re.compile(
'[' + _re.escape(' !"#$%&\'()*+,-./:;<=>?@[\\]^`{|}~') + ']')
def _make_nt(self): def _make_nt(self):
key = tuple(d[0] for d in self.description) if self.description else () key = tuple(d[0] for d in self.description) if self.description else ()
return self._cached_make_nt(key) return self._cached_make_nt(key)
@ -369,7 +365,7 @@ class NamedTupleCursor(_cursor):
def _do_make_nt(cls, key): def _do_make_nt(cls, key):
fields = [] fields = []
for s in key: for s in key:
s = cls._re_clean.sub('_', s) s = _re_clean.sub('_', s)
# Python identifier cannot start with numbers, namedtuple fields # Python identifier cannot start with numbers, namedtuple fields
# cannot start with underscore. So... # cannot start with underscore. So...
if s[0] == '_' or '0' <= s[0] <= '9': if s[0] == '_' or '0' <= s[0] <= '9':
@ -1061,6 +1057,7 @@ class CompositeCaster:
return rv return rv
def _create_type(self, name, attnames): def _create_type(self, name, attnames):
name = _re_clean.sub('_', name)
self.type = namedtuple(name, attnames) self.type = namedtuple(name, attnames)
self._ctor = self.type._make self._ctor = self.type._make
@ -1098,9 +1095,41 @@ ORDER BY attnum;
recs = curs.fetchall() recs = curs.fetchall()
if not recs:
# The above algorithm doesn't work for customized seach_path
# (#1487) The implementation below works better, but, to guarantee
# backwards compatibility, use it only if the original one failed.
try:
savepoint = False
# Because we executed statements earlier, we are either INTRANS
# or we are IDLE only if the transaction is autocommit, in
# which case we don't need the savepoint anyway.
if conn.status == _ext.STATUS_IN_TRANSACTION:
curs.execute("SAVEPOINT register_type")
savepoint = True
curs.execute("""\
SELECT t.oid, %s, attname, atttypid, typname, nspname
FROM pg_type t
JOIN pg_namespace ns ON typnamespace = ns.oid
JOIN pg_attribute a ON attrelid = typrelid
WHERE t.oid = %%s::regtype
AND attnum > 0 AND NOT attisdropped
ORDER BY attnum;
""" % typarray, (name, ))
except psycopg2.ProgrammingError:
pass
else:
recs = curs.fetchall()
if recs:
tname = recs[0][4]
schema = recs[0][5]
finally:
if savepoint:
curs.execute("ROLLBACK TO SAVEPOINT register_type")
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION if conn_status != _ext.STATUS_IN_TRANSACTION and not conn.autocommit:
and not conn.autocommit):
conn.rollback() conn.rollback()
if not recs: if not recs:
@ -1304,3 +1333,8 @@ def _split_sql(sql):
raise ValueError("the query doesn't contain any '%s' placeholder") raise ValueError("the query doesn't contain any '%s' placeholder")
return pre, post return pre, post
# ascii except alnum and underscore
_re_clean = _re.compile(
'[' + _re.escape(' !"#$%&\'()*+,-./:;<=>?@[\\]^`{|}~') + ']')

View File

@ -584,6 +584,68 @@ class AdaptTypeTestCase(ConnectingTestCase):
curs.execute("select (4,8)::typens.typens_ii") curs.execute("select (4,8)::typens.typens_ii")
self.assertEqual(curs.fetchone()[0], (4, 8)) self.assertEqual(curs.fetchone()[0], (4, 8))
@skip_if_no_composite
def test_composite_namespace_path(self):
curs = self.conn.cursor()
curs.execute("""
select nspname from pg_namespace
where nspname = 'typens';
""")
if not curs.fetchone():
curs.execute("create schema typens;")
self.conn.commit()
self._create_type("typens.typensp_ii",
[("a", "integer"), ("b", "integer")])
curs.execute("set search_path=typens,public")
t = psycopg2.extras.register_composite(
"typensp_ii", self.conn)
self.assertEqual(t.schema, 'typens')
curs.execute("select (4,8)::typensp_ii")
self.assertEqual(curs.fetchone()[0], (4, 8))
@skip_if_no_composite
def test_composite_weird_name(self):
curs = self.conn.cursor()
curs.execute("""
select nspname from pg_namespace
where nspname = 'qux.quux';
""")
if not curs.fetchone():
curs.execute('create schema "qux.quux";')
self._create_type('"qux.quux"."foo.bar"',
[("a", "integer"), ("b", "integer")])
t = psycopg2.extras.register_composite(
'"qux.quux"."foo.bar"', self.conn)
self.assertEqual(t.name, 'foo.bar')
self.assertEqual(t.schema, 'qux.quux')
curs.execute('select (4,8)::"qux.quux"."foo.bar"')
self.assertEqual(curs.fetchone()[0], (4, 8))
@skip_if_no_composite
def test_composite_not_found(self):
self.assertRaises(
psycopg2.ProgrammingError, psycopg2.extras.register_composite,
"nosuchtype", self.conn)
self.assertEqual(self.conn.status, ext.STATUS_READY)
cur = self.conn.cursor()
cur.execute("select 1")
self.assertRaises(
psycopg2.ProgrammingError, psycopg2.extras.register_composite,
"nosuchtype", self.conn)
self.assertEqual(self.conn.status, ext.STATUS_IN_TRANSACTION)
self.conn.rollback()
self.conn.autocommit = True
self.assertRaises(
psycopg2.ProgrammingError, psycopg2.extras.register_composite,
"nosuchtype", self.conn)
self.assertEqual(self.conn.status, ext.STATUS_READY)
@skip_if_no_composite @skip_if_no_composite
@skip_before_postgres(8, 4) @skip_before_postgres(8, 4)
def test_composite_array(self): def test_composite_array(self):
@ -710,22 +772,15 @@ class AdaptTypeTestCase(ConnectingTestCase):
def _create_type(self, name, fields): def _create_type(self, name, fields):
curs = self.conn.cursor() curs = self.conn.cursor()
try: try:
curs.execute("savepoint x")
curs.execute(f"drop type {name} cascade;") curs.execute(f"drop type {name} cascade;")
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
self.conn.rollback() curs.execute("rollback to savepoint x")
curs.execute("create type {} as ({});".format(name, curs.execute("create type {} as ({});".format(name,
", ".join(["%s %s" % p for p in fields]))) ", ".join(["%s %s" % p for p in fields])))
if '.' in name:
schema, name = name.split('.')
else:
schema = 'public'
curs.execute("""\ curs.execute("SELECT %s::regtype::oid", (name, ))
SELECT t.oid
FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid
WHERE typname = %s and nspname = %s;
""", (name, schema))
oid = curs.fetchone()[0] oid = curs.fetchone()[0]
self.conn.commit() self.conn.commit()
return oid return oid
@ -1560,6 +1615,18 @@ class RangeCasterTestCase(ConnectingTestCase):
cur = self.conn.cursor() cur = self.conn.cursor()
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
register_range, 'nosuchrange', 'FailRange', cur) register_range, 'nosuchrange', 'FailRange', cur)
self.assertEqual(self.conn.status, ext.STATUS_READY)
cur.execute("select 1")
self.assertRaises(psycopg2.ProgrammingError,
register_range, 'nosuchrange', 'FailRange', cur)
self.assertEqual(self.conn.status, ext.STATUS_IN_TRANSACTION)
self.conn.rollback()
self.conn.autocommit = True
self.assertRaises(psycopg2.ProgrammingError,
register_range, 'nosuchrange', 'FailRange', cur)
@restore_types @restore_types
def test_schema_range(self): def test_schema_range(self):
@ -1574,7 +1641,7 @@ class RangeCasterTestCase(ConnectingTestCase):
register_range('r1', 'r1', cur) register_range('r1', 'r1', cur)
ra2 = register_range('r2', 'r2', cur) ra2 = register_range('r2', 'r2', cur)
rars2 = register_range('rs.r2', 'r2', cur) rars2 = register_range('rs.r2', 'r2', cur)
register_range('rs.r3', 'r3', cur) rars3 = register_range('rs.r3', 'r3', cur)
self.assertNotEqual( self.assertNotEqual(
ra2.typecaster.values[0], ra2.typecaster.values[0],
@ -1588,6 +1655,27 @@ class RangeCasterTestCase(ConnectingTestCase):
register_range, 'rs.r1', 'FailRange', cur) register_range, 'rs.r1', 'FailRange', cur)
cur.execute("rollback to savepoint x;") cur.execute("rollback to savepoint x;")
cur2 = self.conn.cursor()
cur2.execute("set local search_path to rs,public")
ra3 = register_range('r3', 'r3', cur2)
self.assertEqual(ra3.typecaster.values[0], rars3.typecaster.values[0])
@skip_if_no_composite
def test_rang_weird_name(self):
cur = self.conn.cursor()
cur.execute("""
select nspname from pg_namespace
where nspname = 'qux.quux';
""")
if not cur.fetchone():
cur.execute('create schema "qux.quux";')
cur.execute('create type "qux.quux"."foo.range" as range (subtype=text)')
r = psycopg2.extras.register_range(
'"qux.quux"."foo.range"', "foorange", cur)
cur.execute('''select '[a,z]'::"qux.quux"."foo.range"''')
self.assertEqual(cur.fetchone()[0], r.range('a', 'z', '[]'))
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)