diff --git a/lib/__init__.py b/lib/__init__.py index 7676f3c3..f42d081c 100644 --- a/lib/__init__.py +++ b/lib/__init__.py @@ -97,6 +97,24 @@ else: _ext.register_adapter(Decimal, Adapter) del Decimal, Adapter +import re + +def _param_escape(s, + re_escape=re.compile(r"([\\'])"), + re_space=re.compile(r'\s')): + """ + Apply the escaping rule required by PQconnectdb + """ + if not s: return "''" + + s = re_escape.sub(r'\\\1', s) + if re_space.search(s): + s = "'" + s + "'" + + return s + +del re + def connect(dsn=None, database=None, user=None, password=None, host=None, port=None, @@ -147,7 +165,8 @@ def connect(dsn=None, items.extend( [(k, v) for (k, v) in kwargs.iteritems() if v is not None]) - dsn = " ".join(["%s=%s" % item for item in items]) + dsn = " ".join(["%s=%s" % (k, _param_escape(str(v))) + for (k, v) in items]) if not dsn: raise InterfaceError('missing dsn and no parameters') diff --git a/tests/test_module.py b/tests/test_module.py index 66eeccf8..5d45187a 100755 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -99,6 +99,26 @@ class ConnectTestCase(unittest.TestCase): self.assertEqual(self.args[1], None) self.assert_(self.args[2]) + def test_empty_param(self): + psycopg2.connect(database='sony', password='') + self.assertEqual(self.args[0], "dbname=sony password=''") + + def test_escape(self): + psycopg2.connect(database='hello world') + self.assertEqual(self.args[0], "dbname='hello world'") + + psycopg2.connect(database=r'back\slash') + self.assertEqual(self.args[0], r"dbname=back\\slash") + + psycopg2.connect(database="quo'te") + self.assertEqual(self.args[0], r"dbname=quo\'te") + + psycopg2.connect(database="with\ttab") + self.assertEqual(self.args[0], "dbname='with\ttab'") + + psycopg2.connect(database=r"\every thing'") + self.assertEqual(self.args[0], r"dbname='\\every thing\''") + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)