From 625cc1b402b33799757fb9b8fe421a2ea63e1236 Mon Sep 17 00:00:00 2001
From: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Date: Thu, 17 Nov 2011 21:19:01 +0000
Subject: [PATCH] Escape parameters to the connection strings as required by
 PQconnectdb

---
 lib/__init__.py      | 21 ++++++++++++++++++++-
 tests/test_module.py | 20 ++++++++++++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/lib/__init__.py b/lib/__init__.py
index 7676f3c3..f42d081c 100644
--- a/lib/__init__.py
+++ b/lib/__init__.py
@@ -97,6 +97,24 @@ else:
     _ext.register_adapter(Decimal, Adapter)
     del Decimal, Adapter
 
+import re
+
+def _param_escape(s,
+        re_escape=re.compile(r"([\\'])"),
+        re_space=re.compile(r'\s')):
+    """
+    Apply the escaping rule required by PQconnectdb
+    """
+    if not s: return "''"
+
+    s = re_escape.sub(r'\\\1', s)
+    if re_space.search(s):
+        s = "'" + s + "'"
+
+    return s
+
+del re
+
 
 def connect(dsn=None,
         database=None, user=None, password=None, host=None, port=None,
@@ -147,7 +165,8 @@ def connect(dsn=None,
 
         items.extend(
             [(k, v) for (k, v) in kwargs.iteritems() if v is not None])
-        dsn = " ".join(["%s=%s" % item for item in items])
+        dsn = " ".join(["%s=%s" % (k, _param_escape(str(v)))
+            for (k, v) in items])
 
         if not dsn:
             raise InterfaceError('missing dsn and no parameters')
diff --git a/tests/test_module.py b/tests/test_module.py
index 66eeccf8..5d45187a 100755
--- a/tests/test_module.py
+++ b/tests/test_module.py
@@ -99,6 +99,26 @@ class ConnectTestCase(unittest.TestCase):
         self.assertEqual(self.args[1], None)
         self.assert_(self.args[2])
 
+    def test_empty_param(self):
+        psycopg2.connect(database='sony', password='')
+        self.assertEqual(self.args[0], "dbname=sony password=''")
+
+    def test_escape(self):
+        psycopg2.connect(database='hello world')
+        self.assertEqual(self.args[0], "dbname='hello world'")
+
+        psycopg2.connect(database=r'back\slash')
+        self.assertEqual(self.args[0], r"dbname=back\\slash")
+
+        psycopg2.connect(database="quo'te")
+        self.assertEqual(self.args[0], r"dbname=quo\'te")
+
+        psycopg2.connect(database="with\ttab")
+        self.assertEqual(self.args[0], "dbname='with\ttab'")
+
+        psycopg2.connect(database=r"\every thing'")
+        self.assertEqual(self.args[0], r"dbname='\\every thing\''")
+
 
 def test_suite():
     return unittest.TestLoader().loadTestsFromName(__name__)