fix: correctly handle composites with names or schema requiring escape

This commit is contained in:
Daniele Varrazzo 2022-10-06 01:27:10 +01:00
parent d88e4c2a3c
commit 9535462ce9
3 changed files with 33 additions and 17 deletions

1
NEWS
View File

@ -2,6 +2,7 @@ What's new in psycopg 2.9.4 (unreleased)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- Fix `register_composite()` with customized search_path (:ticket:`#1487`). - Fix `register_composite()` with customized search_path (:ticket:`#1487`).
- Handle correctly composite types with names or in schemas requiring escape.
Current release Current release

View File

@ -357,10 +357,6 @@ class NamedTupleCursor(_cursor):
except StopIteration: except StopIteration:
return return
# ascii except alnum and underscore
_re_clean = _re.compile(
'[' + _re.escape(' !"#$%&\'()*+,-./:;<=>?@[\\]^`{|}~') + ']')
def _make_nt(self): def _make_nt(self):
key = tuple(d[0] for d in self.description) if self.description else () key = tuple(d[0] for d in self.description) if self.description else ()
return self._cached_make_nt(key) return self._cached_make_nt(key)
@ -369,7 +365,7 @@ class NamedTupleCursor(_cursor):
def _do_make_nt(cls, key): def _do_make_nt(cls, key):
fields = [] fields = []
for s in key: for s in key:
s = cls._re_clean.sub('_', s) s = _re_clean.sub('_', s)
# Python identifier cannot start with numbers, namedtuple fields # Python identifier cannot start with numbers, namedtuple fields
# cannot start with underscore. So... # cannot start with underscore. So...
if s[0] == '_' or '0' <= s[0] <= '9': if s[0] == '_' or '0' <= s[0] <= '9':
@ -1061,6 +1057,7 @@ class CompositeCaster:
return rv return rv
def _create_type(self, name, attnames): def _create_type(self, name, attnames):
name = _re_clean.sub('_', name)
self.type = namedtuple(name, attnames) self.type = namedtuple(name, attnames)
self._ctor = self.type._make self._ctor = self.type._make
@ -1112,7 +1109,7 @@ ORDER BY attnum;
savepoint = True savepoint = True
curs.execute("""\ curs.execute("""\
SELECT t.oid, %s, attname, atttypid, nspname SELECT t.oid, %s, attname, atttypid, typname, nspname
FROM pg_type t FROM pg_type t
JOIN pg_namespace ns ON typnamespace = ns.oid JOIN pg_namespace ns ON typnamespace = ns.oid
JOIN pg_attribute a ON attrelid = typrelid JOIN pg_attribute a ON attrelid = typrelid
@ -1125,7 +1122,8 @@ ORDER BY attnum;
else: else:
recs = curs.fetchall() recs = curs.fetchall()
if recs: if recs:
schema = recs[0][4] tname = recs[0][4]
schema = recs[0][5]
finally: finally:
if savepoint: if savepoint:
curs.execute("ROLLBACK TO SAVEPOINT register_type") curs.execute("ROLLBACK TO SAVEPOINT register_type")
@ -1335,3 +1333,8 @@ def _split_sql(sql):
raise ValueError("the query doesn't contain any '%s' placeholder") raise ValueError("the query doesn't contain any '%s' placeholder")
return pre, post return pre, post
# ascii except alnum and underscore
_re_clean = _re.compile(
'[' + _re.escape(' !"#$%&\'()*+,-./:;<=>?@[\\]^`{|}~') + ']')

View File

@ -604,6 +604,25 @@ class AdaptTypeTestCase(ConnectingTestCase):
curs.execute("select (4,8)::typensp_ii") curs.execute("select (4,8)::typensp_ii")
self.assertEqual(curs.fetchone()[0], (4, 8)) self.assertEqual(curs.fetchone()[0], (4, 8))
@skip_if_no_composite
def test_composite_weird_name(self):
curs = self.conn.cursor()
curs.execute("""
select nspname from pg_namespace
where nspname = 'qux.quux';
""")
if not curs.fetchone():
curs.execute('create schema "qux.quux";')
self._create_type('"qux.quux"."foo.bar"',
[("a", "integer"), ("b", "integer")])
t = psycopg2.extras.register_composite(
'"qux.quux"."foo.bar"', self.conn)
self.assertEqual(t.name, 'foo.bar')
self.assertEqual(t.schema, 'qux.quux')
curs.execute('select (4,8)::"qux.quux"."foo.bar"')
self.assertEqual(curs.fetchone()[0], (4, 8))
@skip_if_no_composite @skip_if_no_composite
def test_composite_not_found(self): def test_composite_not_found(self):
@ -753,22 +772,15 @@ class AdaptTypeTestCase(ConnectingTestCase):
def _create_type(self, name, fields): def _create_type(self, name, fields):
curs = self.conn.cursor() curs = self.conn.cursor()
try: try:
curs.execute("savepoint x")
curs.execute(f"drop type {name} cascade;") curs.execute(f"drop type {name} cascade;")
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
self.conn.rollback() curs.execute("rollback to savepoint x")
curs.execute("create type {} as ({});".format(name, curs.execute("create type {} as ({});".format(name,
", ".join(["%s %s" % p for p in fields]))) ", ".join(["%s %s" % p for p in fields])))
if '.' in name:
schema, name = name.split('.')
else:
schema = 'public'
curs.execute("""\ curs.execute("SELECT %s::regtype::oid", (name, ))
SELECT t.oid
FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid
WHERE typname = %s and nspname = %s;
""", (name, schema))
oid = curs.fetchone()[0] oid = curs.fetchone()[0]
self.conn.commit() self.conn.commit()
return oid return oid