Escape parameters to the connection strings as required by PQconnectdb

This commit is contained in:
Daniele Varrazzo 2011-11-17 21:19:01 +00:00
parent 99ac79511c
commit 625cc1b402
2 changed files with 40 additions and 1 deletions

View File

@ -97,6 +97,24 @@ else:
_ext.register_adapter(Decimal, Adapter) _ext.register_adapter(Decimal, Adapter)
del Decimal, Adapter del Decimal, Adapter
import re
def _param_escape(s,
re_escape=re.compile(r"([\\'])"),
re_space=re.compile(r'\s')):
"""
Apply the escaping rule required by PQconnectdb
"""
if not s: return "''"
s = re_escape.sub(r'\\\1', s)
if re_space.search(s):
s = "'" + s + "'"
return s
del re
def connect(dsn=None, def connect(dsn=None,
database=None, user=None, password=None, host=None, port=None, database=None, user=None, password=None, host=None, port=None,
@ -147,7 +165,8 @@ def connect(dsn=None,
items.extend( items.extend(
[(k, v) for (k, v) in kwargs.iteritems() if v is not None]) [(k, v) for (k, v) in kwargs.iteritems() if v is not None])
dsn = " ".join(["%s=%s" % item for item in items]) dsn = " ".join(["%s=%s" % (k, _param_escape(str(v)))
for (k, v) in items])
if not dsn: if not dsn:
raise InterfaceError('missing dsn and no parameters') raise InterfaceError('missing dsn and no parameters')

View File

@ -99,6 +99,26 @@ class ConnectTestCase(unittest.TestCase):
self.assertEqual(self.args[1], None) self.assertEqual(self.args[1], None)
self.assert_(self.args[2]) self.assert_(self.args[2])
def test_empty_param(self):
psycopg2.connect(database='sony', password='')
self.assertEqual(self.args[0], "dbname=sony password=''")
def test_escape(self):
psycopg2.connect(database='hello world')
self.assertEqual(self.args[0], "dbname='hello world'")
psycopg2.connect(database=r'back\slash')
self.assertEqual(self.args[0], r"dbname=back\\slash")
psycopg2.connect(database="quo'te")
self.assertEqual(self.args[0], r"dbname=quo\'te")
psycopg2.connect(database="with\ttab")
self.assertEqual(self.args[0], "dbname='with\ttab'")
psycopg2.connect(database=r"\every thing'")
self.assertEqual(self.args[0], r"dbname='\\every thing\''")
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)