From 2c55a1bd5394ef82a49984fdce3c17ce956a9c9e Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 3 Mar 2016 15:07:38 +0000 Subject: [PATCH] Verify that the dsn is not manipulated by make_dsn if not necessary --- lib/__init__.py | 3 --- lib/extensions.py | 7 +++++++ tests/test_module.py | 37 +++++++++++++++++++++++++------------ 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/lib/__init__.py b/lib/__init__.py index 608b5d14..4a288197 100644 --- a/lib/__init__.py +++ b/lib/__init__.py @@ -116,9 +116,6 @@ def connect(dsn=None, connection_factory=None, cursor_factory=None, library: the list of supported parameters depends on the library version. """ - if dsn is None and not kwargs: - raise TypeError('missing dsn and no parameters') - dsn = _ext.make_dsn(dsn, **kwargs) conn = _connect(dsn, connection_factory=connection_factory, async=async) if cursor_factory is not None: diff --git a/lib/extensions.py b/lib/extensions.py index 3024b2fd..469f1932 100644 --- a/lib/extensions.py +++ b/lib/extensions.py @@ -157,6 +157,13 @@ class NoneAdapter(object): def make_dsn(dsn=None, **kwargs): """Convert a set of keywords into a connection strings.""" + if dsn is None and not kwargs: + raise TypeError('missing dsn and no parameters') + + # If no kwarg is specified don't mung the dsn + if not kwargs: + return dsn + # Override the dsn with the parameters if 'database' in kwargs: if 'dbname' in kwargs: diff --git a/tests/test_module.py b/tests/test_module.py index c0e4bf87..9f0adcc9 100755 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -31,9 +31,11 @@ from testutils import ConnectingTestCase, skip_copy_if_green, script_to_py3 import psycopg2 + class ConnectTestCase(unittest.TestCase): def setUp(self): self.args = None + def conect_stub(dsn, connection_factory=None, async=False): self.args = (dsn, connection_factory, async) @@ -60,8 +62,8 @@ class ConnectTestCase(unittest.TestCase): self.assertEqual(self.args[2], False) def test_dsn(self): - psycopg2.connect('dbname=blah application_name=y') - self.assertDsnEqual(self.args[0], 'dbname=blah application_name=y') + psycopg2.connect('dbname=blah x=y') + self.assertEqual(self.args[0], 'dbname=blah x=y') self.assertEqual(self.args[1], None) self.assertEqual(self.args[2], False) @@ -93,24 +95,24 @@ class ConnectTestCase(unittest.TestCase): def f(dsn, async=False): pass - psycopg2.connect(database='foo', application_name='baz', connection_factory=f) - self.assertDsnEqual(self.args[0], 'dbname=foo application_name=baz') + psycopg2.connect(database='foo', bar='baz', connection_factory=f) + self.assertDsnEqual(self.args[0], 'dbname=foo bar=baz') self.assertEqual(self.args[1], f) self.assertEqual(self.args[2], False) - psycopg2.connect("dbname=foo application_name=baz", connection_factory=f) - self.assertDsnEqual(self.args[0], 'dbname=foo application_name=baz') + psycopg2.connect("dbname=foo bar=baz", connection_factory=f) + self.assertDsnEqual(self.args[0], 'dbname=foo bar=baz') self.assertEqual(self.args[1], f) self.assertEqual(self.args[2], False) def test_async(self): - psycopg2.connect(database='foo', application_name='baz', async=1) - self.assertDsnEqual(self.args[0], 'dbname=foo application_name=baz') + psycopg2.connect(database='foo', bar='baz', async=1) + self.assertDsnEqual(self.args[0], 'dbname=foo bar=baz') self.assertEqual(self.args[1], None) self.assert_(self.args[2]) - psycopg2.connect("dbname=foo application_name=baz", async=True) - self.assertDsnEqual(self.args[0], 'dbname=foo application_name=baz') + psycopg2.connect("dbname=foo bar=baz", async=True) + self.assertDsnEqual(self.args[0], 'dbname=foo bar=baz') self.assertEqual(self.args[1], None) self.assert_(self.args[2]) @@ -141,6 +143,16 @@ class ConnectTestCase(unittest.TestCase): psycopg2.connect('dbname=foo', user='postgres') self.assertDsnEqual(self.args[0], 'dbname=foo user=postgres') + def test_no_dsn_munging(self): + psycopg2.connect('nosuchparam=whatevs') + self.assertEqual(self.args[0], 'nosuchparam=whatevs') + + psycopg2.connect(nosuchparam='whatevs') + self.assertEqual(self.args[0], 'nosuchparam=whatevs') + + self.assertRaises(psycopg2.ProgrammingError, + psycopg2.connect, 'nosuchparam=whatevs', andthis='either') + class ExceptionsTestCase(ConnectingTestCase): def test_attributes(self): @@ -205,7 +217,8 @@ class ExceptionsTestCase(ConnectingTestCase): self.assertEqual(diag.sqlstate, '42P01') del diag - gc.collect(); gc.collect() + gc.collect() + gc.collect() assert(w() is None) @skip_copy_if_green @@ -327,7 +340,7 @@ class TestVersionDiscovery(unittest.TestCase): self.assertTrue(type(psycopg2.__libpq_version__) is int) try: self.assertTrue(type(psycopg2.extensions.libpq_version()) is int) - except NotSupportedError: + except psycopg2.NotSupportedError: self.assertTrue(psycopg2.__libpq_version__ < 90100)