From 9535462ce9b4a8f9fe7b680e5f82add59cb51161 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 6 Oct 2022 01:27:10 +0100 Subject: [PATCH] fix: correctly handle composites with names or schema requiring escape --- NEWS | 1 + lib/extras.py | 17 ++++++++++------- tests/test_types_extras.py | 32 ++++++++++++++++++++++---------- 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/NEWS b/NEWS index d599cd66..78f1b8e7 100644 --- a/NEWS +++ b/NEWS @@ -2,6 +2,7 @@ What's new in psycopg 2.9.4 (unreleased) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - Fix `register_composite()` with customized search_path (:ticket:`#1487`). +- Handle correctly composite types with names or in schemas requiring escape. Current release diff --git a/lib/extras.py b/lib/extras.py index 5d6f20e7..36e8ef9a 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -357,10 +357,6 @@ class NamedTupleCursor(_cursor): except StopIteration: return - # ascii except alnum and underscore - _re_clean = _re.compile( - '[' + _re.escape(' !"#$%&\'()*+,-./:;<=>?@[\\]^`{|}~') + ']') - def _make_nt(self): key = tuple(d[0] for d in self.description) if self.description else () return self._cached_make_nt(key) @@ -369,7 +365,7 @@ class NamedTupleCursor(_cursor): def _do_make_nt(cls, key): fields = [] for s in key: - s = cls._re_clean.sub('_', s) + s = _re_clean.sub('_', s) # Python identifier cannot start with numbers, namedtuple fields # cannot start with underscore. So... if s[0] == '_' or '0' <= s[0] <= '9': @@ -1061,6 +1057,7 @@ class CompositeCaster: return rv def _create_type(self, name, attnames): + name = _re_clean.sub('_', name) self.type = namedtuple(name, attnames) self._ctor = self.type._make @@ -1112,7 +1109,7 @@ ORDER BY attnum; savepoint = True curs.execute("""\ -SELECT t.oid, %s, attname, atttypid, nspname +SELECT t.oid, %s, attname, atttypid, typname, nspname FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid JOIN pg_attribute a ON attrelid = typrelid @@ -1125,7 +1122,8 @@ ORDER BY attnum; else: recs = curs.fetchall() if recs: - schema = recs[0][4] + tname = recs[0][4] + schema = recs[0][5] finally: if savepoint: curs.execute("ROLLBACK TO SAVEPOINT register_type") @@ -1335,3 +1333,8 @@ def _split_sql(sql): raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post + + +# ascii except alnum and underscore +_re_clean = _re.compile( + '[' + _re.escape(' !"#$%&\'()*+,-./:;<=>?@[\\]^`{|}~') + ']') diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py index 640c8615..c50c59ff 100755 --- a/tests/test_types_extras.py +++ b/tests/test_types_extras.py @@ -604,6 +604,25 @@ class AdaptTypeTestCase(ConnectingTestCase): curs.execute("select (4,8)::typensp_ii") self.assertEqual(curs.fetchone()[0], (4, 8)) + @skip_if_no_composite + def test_composite_weird_name(self): + curs = self.conn.cursor() + curs.execute(""" + select nspname from pg_namespace + where nspname = 'qux.quux'; + """) + if not curs.fetchone(): + curs.execute('create schema "qux.quux";') + + self._create_type('"qux.quux"."foo.bar"', + [("a", "integer"), ("b", "integer")]) + t = psycopg2.extras.register_composite( + '"qux.quux"."foo.bar"', self.conn) + self.assertEqual(t.name, 'foo.bar') + self.assertEqual(t.schema, 'qux.quux') + curs.execute('select (4,8)::"qux.quux"."foo.bar"') + self.assertEqual(curs.fetchone()[0], (4, 8)) + @skip_if_no_composite def test_composite_not_found(self): @@ -753,22 +772,15 @@ class AdaptTypeTestCase(ConnectingTestCase): def _create_type(self, name, fields): curs = self.conn.cursor() try: + curs.execute("savepoint x") curs.execute(f"drop type {name} cascade;") except psycopg2.ProgrammingError: - self.conn.rollback() + curs.execute("rollback to savepoint x") curs.execute("create type {} as ({});".format(name, ", ".join(["%s %s" % p for p in fields]))) - if '.' in name: - schema, name = name.split('.') - else: - schema = 'public' - curs.execute("""\ - SELECT t.oid - FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid - WHERE typname = %s and nspname = %s; - """, (name, schema)) + curs.execute("SELECT %s::regtype::oid", (name, )) oid = curs.fetchone()[0] self.conn.commit() return oid