mirror of
				https://github.com/psycopg/psycopg2.git
				synced 2025-11-04 01:37:31 +03:00 
			
		
		
		
	Escape parameters to the connection strings as required by PQconnectdb
This commit is contained in:
		
							parent
							
								
									99ac79511c
								
							
						
					
					
						commit
						625cc1b402
					
				| 
						 | 
				
			
			@ -97,6 +97,24 @@ else:
 | 
			
		|||
    _ext.register_adapter(Decimal, Adapter)
 | 
			
		||||
    del Decimal, Adapter
 | 
			
		||||
 | 
			
		||||
import re
 | 
			
		||||
 | 
			
		||||
def _param_escape(s,
 | 
			
		||||
        re_escape=re.compile(r"([\\'])"),
 | 
			
		||||
        re_space=re.compile(r'\s')):
 | 
			
		||||
    """
 | 
			
		||||
    Apply the escaping rule required by PQconnectdb
 | 
			
		||||
    """
 | 
			
		||||
    if not s: return "''"
 | 
			
		||||
 | 
			
		||||
    s = re_escape.sub(r'\\\1', s)
 | 
			
		||||
    if re_space.search(s):
 | 
			
		||||
        s = "'" + s + "'"
 | 
			
		||||
 | 
			
		||||
    return s
 | 
			
		||||
 | 
			
		||||
del re
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def connect(dsn=None,
 | 
			
		||||
        database=None, user=None, password=None, host=None, port=None,
 | 
			
		||||
| 
						 | 
				
			
			@ -147,7 +165,8 @@ def connect(dsn=None,
 | 
			
		|||
 | 
			
		||||
        items.extend(
 | 
			
		||||
            [(k, v) for (k, v) in kwargs.iteritems() if v is not None])
 | 
			
		||||
        dsn = " ".join(["%s=%s" % item for item in items])
 | 
			
		||||
        dsn = " ".join(["%s=%s" % (k, _param_escape(str(v)))
 | 
			
		||||
            for (k, v) in items])
 | 
			
		||||
 | 
			
		||||
        if not dsn:
 | 
			
		||||
            raise InterfaceError('missing dsn and no parameters')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -99,6 +99,26 @@ class ConnectTestCase(unittest.TestCase):
 | 
			
		|||
        self.assertEqual(self.args[1], None)
 | 
			
		||||
        self.assert_(self.args[2])
 | 
			
		||||
 | 
			
		||||
    def test_empty_param(self):
 | 
			
		||||
        psycopg2.connect(database='sony', password='')
 | 
			
		||||
        self.assertEqual(self.args[0], "dbname=sony password=''")
 | 
			
		||||
 | 
			
		||||
    def test_escape(self):
 | 
			
		||||
        psycopg2.connect(database='hello world')
 | 
			
		||||
        self.assertEqual(self.args[0], "dbname='hello world'")
 | 
			
		||||
 | 
			
		||||
        psycopg2.connect(database=r'back\slash')
 | 
			
		||||
        self.assertEqual(self.args[0], r"dbname=back\\slash")
 | 
			
		||||
 | 
			
		||||
        psycopg2.connect(database="quo'te")
 | 
			
		||||
        self.assertEqual(self.args[0], r"dbname=quo\'te")
 | 
			
		||||
 | 
			
		||||
        psycopg2.connect(database="with\ttab")
 | 
			
		||||
        self.assertEqual(self.args[0], "dbname='with\ttab'")
 | 
			
		||||
 | 
			
		||||
        psycopg2.connect(database=r"\every thing'")
 | 
			
		||||
        self.assertEqual(self.args[0], r"dbname='\\every thing\''")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_suite():
 | 
			
		||||
    return unittest.TestLoader().loadTestsFromName(__name__)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user