mirror of
				https://github.com/psycopg/psycopg2.git
				synced 2025-11-04 09:47:30 +03:00 
			
		
		
		
	fix: look up for range types defined in schemas in the search path
This commit is contained in:
		
							parent
							
								
									9535462ce9
								
							
						
					
					
						commit
						ac25d3bdc0
					
				
							
								
								
									
										3
									
								
								NEWS
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								NEWS
									
									
									
									
									
								
							| 
						 | 
					@ -1,7 +1,8 @@
 | 
				
			||||||
What's new in psycopg 2.9.4 (unreleased)
 | 
					What's new in psycopg 2.9.4 (unreleased)
 | 
				
			||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 | 
					^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- Fix `register_composite()` with customized search_path (:ticket:`#1487`).
 | 
					- Fix `register_composite()`, `register_range()` with customized search_path
 | 
				
			||||||
 | 
					  (:ticket:`#1487`).
 | 
				
			||||||
- Handle correctly composite types with names or in schemas requiring escape.
 | 
					- Handle correctly composite types with names or in schemas requiring escape.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -363,33 +363,54 @@ class RangeCaster:
 | 
				
			||||||
            schema = 'public'
 | 
					            schema = 'public'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # get the type oid and attributes
 | 
					        # get the type oid and attributes
 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
        curs.execute("""\
 | 
					        curs.execute("""\
 | 
				
			||||||
select rngtypid, rngsubtype,
 | 
					select rngtypid, rngsubtype, typarray
 | 
				
			||||||
    (select typarray from pg_type where oid = rngtypid)
 | 
					 | 
				
			||||||
from pg_range r
 | 
					from pg_range r
 | 
				
			||||||
join pg_type t on t.oid = rngtypid
 | 
					join pg_type t on t.oid = rngtypid
 | 
				
			||||||
join pg_namespace ns on ns.oid = typnamespace
 | 
					join pg_namespace ns on ns.oid = typnamespace
 | 
				
			||||||
where typname = %s and ns.nspname = %s;
 | 
					where typname = %s and ns.nspname = %s;
 | 
				
			||||||
""", (tname, schema))
 | 
					""", (tname, schema))
 | 
				
			||||||
 | 
					 | 
				
			||||||
        except ProgrammingError:
 | 
					 | 
				
			||||||
            if not conn.autocommit:
 | 
					 | 
				
			||||||
                conn.rollback()
 | 
					 | 
				
			||||||
            raise
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
        rec = curs.fetchone()
 | 
					        rec = curs.fetchone()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not rec:
 | 
				
			||||||
 | 
					            # The above algorithm doesn't work for customized seach_path
 | 
				
			||||||
 | 
					            # (#1487) The implementation below works better, but, to guarantee
 | 
				
			||||||
 | 
					            # backwards compatibility, use it only if the original one failed.
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                savepoint = False
 | 
				
			||||||
 | 
					                # Because we executed statements earlier, we are either INTRANS
 | 
				
			||||||
 | 
					                # or we are IDLE only if the transaction is autocommit, in
 | 
				
			||||||
 | 
					                # which case we don't need the savepoint anyway.
 | 
				
			||||||
 | 
					                if conn.status == STATUS_IN_TRANSACTION:
 | 
				
			||||||
 | 
					                    curs.execute("SAVEPOINT register_type")
 | 
				
			||||||
 | 
					                    savepoint = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                curs.execute("""\
 | 
				
			||||||
 | 
					SELECT rngtypid, rngsubtype, typarray, typname, nspname
 | 
				
			||||||
 | 
					from pg_range r
 | 
				
			||||||
 | 
					join pg_type t on t.oid = rngtypid
 | 
				
			||||||
 | 
					join pg_namespace ns on ns.oid = typnamespace
 | 
				
			||||||
 | 
					WHERE t.oid = %s::regtype
 | 
				
			||||||
 | 
					""", (name, ))
 | 
				
			||||||
 | 
					            except ProgrammingError:
 | 
				
			||||||
 | 
					                pass
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                rec = curs.fetchone()
 | 
				
			||||||
 | 
					                if rec:
 | 
				
			||||||
 | 
					                    tname, schema = rec[3:]
 | 
				
			||||||
 | 
					            finally:
 | 
				
			||||||
 | 
					                if savepoint:
 | 
				
			||||||
 | 
					                    curs.execute("ROLLBACK TO SAVEPOINT register_type")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # revert the status of the connection as before the command
 | 
					        # revert the status of the connection as before the command
 | 
				
			||||||
            if (conn_status != STATUS_IN_TRANSACTION
 | 
					        if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
 | 
				
			||||||
            and not conn.autocommit):
 | 
					 | 
				
			||||||
            conn.rollback()
 | 
					            conn.rollback()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not rec:
 | 
					        if not rec:
 | 
				
			||||||
            raise ProgrammingError(
 | 
					            raise ProgrammingError(
 | 
				
			||||||
                f"PostgreSQL type '{name}' not found")
 | 
					                f"PostgreSQL range '{name}' not found")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        type, subtype, array = rec
 | 
					        type, subtype, array = rec[:3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return RangeCaster(name, pyrange,
 | 
					        return RangeCaster(name, pyrange,
 | 
				
			||||||
            oid=type, subtype_oid=subtype, array_oid=array)
 | 
					            oid=type, subtype_oid=subtype, array_oid=array)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1615,6 +1615,18 @@ class RangeCasterTestCase(ConnectingTestCase):
 | 
				
			||||||
        cur = self.conn.cursor()
 | 
					        cur = self.conn.cursor()
 | 
				
			||||||
        self.assertRaises(psycopg2.ProgrammingError,
 | 
					        self.assertRaises(psycopg2.ProgrammingError,
 | 
				
			||||||
            register_range, 'nosuchrange', 'FailRange', cur)
 | 
					            register_range, 'nosuchrange', 'FailRange', cur)
 | 
				
			||||||
 | 
					        self.assertEqual(self.conn.status, ext.STATUS_READY)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cur.execute("select 1")
 | 
				
			||||||
 | 
					        self.assertRaises(psycopg2.ProgrammingError,
 | 
				
			||||||
 | 
					            register_range, 'nosuchrange', 'FailRange', cur)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(self.conn.status, ext.STATUS_IN_TRANSACTION)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.conn.rollback()
 | 
				
			||||||
 | 
					        self.conn.autocommit = True
 | 
				
			||||||
 | 
					        self.assertRaises(psycopg2.ProgrammingError,
 | 
				
			||||||
 | 
					            register_range, 'nosuchrange', 'FailRange', cur)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @restore_types
 | 
					    @restore_types
 | 
				
			||||||
    def test_schema_range(self):
 | 
					    def test_schema_range(self):
 | 
				
			||||||
| 
						 | 
					@ -1629,7 +1641,7 @@ class RangeCasterTestCase(ConnectingTestCase):
 | 
				
			||||||
        register_range('r1', 'r1', cur)
 | 
					        register_range('r1', 'r1', cur)
 | 
				
			||||||
        ra2 = register_range('r2', 'r2', cur)
 | 
					        ra2 = register_range('r2', 'r2', cur)
 | 
				
			||||||
        rars2 = register_range('rs.r2', 'r2', cur)
 | 
					        rars2 = register_range('rs.r2', 'r2', cur)
 | 
				
			||||||
        register_range('rs.r3', 'r3', cur)
 | 
					        rars3 = register_range('rs.r3', 'r3', cur)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertNotEqual(
 | 
					        self.assertNotEqual(
 | 
				
			||||||
            ra2.typecaster.values[0],
 | 
					            ra2.typecaster.values[0],
 | 
				
			||||||
| 
						 | 
					@ -1643,6 +1655,11 @@ class RangeCasterTestCase(ConnectingTestCase):
 | 
				
			||||||
            register_range, 'rs.r1', 'FailRange', cur)
 | 
					            register_range, 'rs.r1', 'FailRange', cur)
 | 
				
			||||||
        cur.execute("rollback to savepoint x;")
 | 
					        cur.execute("rollback to savepoint x;")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cur2 = self.conn.cursor()
 | 
				
			||||||
 | 
					        cur2.execute("set local search_path to rs,public")
 | 
				
			||||||
 | 
					        ra3 = register_range('r3', 'r3', cur2)
 | 
				
			||||||
 | 
					        self.assertEqual(ra3.typecaster.values[0], rars3.typecaster.values[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_suite():
 | 
					def test_suite():
 | 
				
			||||||
    return unittest.TestLoader().loadTestsFromName(__name__)
 | 
					    return unittest.TestLoader().loadTestsFromName(__name__)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user