Validate output result from make_dsn()

The output is not necessarily munged anyway: if no keyword is passed,
validate the input but return it untouched.
This commit is contained in:
Daniele Varrazzo 2016-03-03 16:52:53 +00:00
parent 7155d06cdc
commit 7aab934ae5
5 changed files with 70 additions and 50 deletions

View File

@ -493,15 +493,19 @@ Other functions
.. function:: make_dsn(dsn=None, \*\*kwargs) .. function:: make_dsn(dsn=None, \*\*kwargs)
Create a connection string from arguments. Create a valid connection string from arguments.
Put together the arguments in *kwargs* into a connection string. If *dsn* Put together the arguments in *kwargs* into a connection string. If *dsn*
is specified too, merge the arguments coming from both the sources. If the is specified too, merge the arguments coming from both the sources. If the
same argument is specified in both the sources, the *kwargs* version same argument is specified in both the sources, the *kwargs* version
overrides the *dsn* version overrides the *dsn* version.
At least one param is required (either *dsn* or any keyword). Note that At least one parameter is required (either *dsn* or any keyword). Note
the empty string is a valid connection string. that the empty string is a valid connection string.
The input arguments are validated: the output should always be a valid
connection string (as far as `parse_dsn()` is concerned). If not raise
`~psycopg2.ProgrammingError`.
Example:: Example::
@ -516,17 +520,27 @@ Other functions
Parse connection string into a dictionary of keywords and values. Parse connection string into a dictionary of keywords and values.
Uses libpq's ``PQconninfoParse`` to parse the string according to Parsing is delegated to the libpq: different versions of the client
accepted format(s) and check for supported keywords. library may support different formats or parameters (for example,
`connection URIs`__ are only supported from libpq 9.2). Raise
`~psycopg2.ProgrammingError` if the *dsn* is not valid.
.. __: http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
Example:: Example::
>>> from psycopg2.extensions import parse_dsn >>> from psycopg2.extensions import parse_dsn
>>> parse_dsn('dbname=test user=postgres password=secret') >>> parse_dsn('dbname=test user=postgres password=secret')
{'password': 'secret', 'user': 'postgres', 'dbname': 'test'} {'password': 'secret', 'user': 'postgres', 'dbname': 'test'}
>>> parse_dsn("postgresql://someone@example.com/somedb?connect_timeout=10")
{'host': 'example.com', 'user': 'someone', 'dbname': 'somedb', 'connect_timeout': '10'}
.. versionadded:: 2.7 .. versionadded:: 2.7
.. seealso:: libpq docs for `PQconninfoParse()`__.
.. __: http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-PQCONNINFOPARSE
.. function:: quote_ident(str, scope) .. function:: quote_ident(str, scope)

View File

@ -160,8 +160,9 @@ def make_dsn(dsn=None, **kwargs):
if dsn is None and not kwargs: if dsn is None and not kwargs:
raise TypeError('missing dsn and no parameters') raise TypeError('missing dsn and no parameters')
# If no kwarg is specified don't mung the dsn # If no kwarg is specified don't mung the dsn, but verify it
if not kwargs: if not kwargs:
parse_dsn(dsn)
return dsn return dsn
# Override the dsn with the parameters # Override the dsn with the parameters
@ -178,6 +179,10 @@ def make_dsn(dsn=None, **kwargs):
dsn = " ".join(["%s=%s" % (k, _param_escape(str(v))) dsn = " ".join(["%s=%s" % (k, _param_escape(str(v)))
for (k, v) in kwargs.iteritems()]) for (k, v) in kwargs.iteritems()])
# verify that the returned dsn is valid
parse_dsn(dsn)
return dsn return dsn

View File

@ -133,7 +133,7 @@ psyco_parse_dsn(PyObject *self, PyObject *args, PyObject *kwargs)
options = PQconninfoParse(Bytes_AS_STRING(dsn), &err); options = PQconninfoParse(Bytes_AS_STRING(dsn), &err);
if (options == NULL) { if (options == NULL) {
if (err != NULL) { if (err != NULL) {
PyErr_Format(ProgrammingError, "error parsing the dsn: %s", err); PyErr_Format(ProgrammingError, "invalid dsn: %s", err);
PQfreemem(err); PQfreemem(err);
} else { } else {
PyErr_SetString(OperationalError, "PQconninfoParse() failed"); PyErr_SetString(OperationalError, "PQconninfoParse() failed");

View File

@ -388,53 +388,64 @@ class MakeDsnTestCase(ConnectingTestCase):
dsn = ext.make_dsn('') dsn = ext.make_dsn('')
self.assertEqual(dsn, '') self.assertEqual(dsn, '')
def test_params_validation(self):
self.assertRaises(psycopg2.ProgrammingError,
ext.make_dsn, 'dbnamo=a')
self.assertRaises(psycopg2.ProgrammingError,
ext.make_dsn, dbnamo='a')
self.assertRaises(psycopg2.ProgrammingError,
ext.make_dsn, 'dbname=a', nosuchparam='b')
def test_empty_param(self): def test_empty_param(self):
dsn = ext.make_dsn(database='sony', password='') dsn = ext.make_dsn(dbname='sony', password='')
self.assertDsnEqual(dsn, "dbname=sony password=''") self.assertDsnEqual(dsn, "dbname=sony password=''")
def test_escape(self): def test_escape(self):
dsn = ext.make_dsn(database='hello world') dsn = ext.make_dsn(dbname='hello world')
self.assertEqual(dsn, "dbname='hello world'") self.assertEqual(dsn, "dbname='hello world'")
dsn = ext.make_dsn(database=r'back\slash') dsn = ext.make_dsn(dbname=r'back\slash')
self.assertEqual(dsn, r"dbname=back\\slash") self.assertEqual(dsn, r"dbname=back\\slash")
dsn = ext.make_dsn(database="quo'te") dsn = ext.make_dsn(dbname="quo'te")
self.assertEqual(dsn, r"dbname=quo\'te") self.assertEqual(dsn, r"dbname=quo\'te")
dsn = ext.make_dsn(database="with\ttab") dsn = ext.make_dsn(dbname="with\ttab")
self.assertEqual(dsn, "dbname='with\ttab'") self.assertEqual(dsn, "dbname='with\ttab'")
dsn = ext.make_dsn(database=r"\every thing'") dsn = ext.make_dsn(dbname=r"\every thing'")
self.assertEqual(dsn, r"dbname='\\every thing\''") self.assertEqual(dsn, r"dbname='\\every thing\''")
def test_database_is_a_keyword(self):
self.assertEqual(ext.make_dsn(database='sigh'), "dbname=sigh")
def test_params_merging(self): def test_params_merging(self):
dsn = ext.make_dsn('dbname=foo', database='bar') dsn = ext.make_dsn('dbname=foo host=bar', host='baz')
self.assertEqual(dsn, 'dbname=bar') self.assertDsnEqual(dsn, 'dbname=foo host=baz')
dsn = ext.make_dsn('dbname=foo', user='postgres') dsn = ext.make_dsn('dbname=foo', user='postgres')
self.assertDsnEqual(dsn, 'dbname=foo user=postgres') self.assertDsnEqual(dsn, 'dbname=foo user=postgres')
def test_no_dsn_munging(self): def test_no_dsn_munging(self):
dsn = ext.make_dsn('nosuchparam=whatevs') dsnin = 'dbname=a host=b user=c password=d'
self.assertEqual(dsn, 'nosuchparam=whatevs') dsn = ext.make_dsn(dsnin)
self.assertEqual(dsn, dsnin)
dsn = ext.make_dsn(nosuchparam='whatevs')
self.assertEqual(dsn, 'nosuchparam=whatevs')
self.assertRaises(psycopg2.ProgrammingError,
ext.make_dsn, 'nosuchparam=whatevs', andthis='either')
@skip_before_libpq(9, 2) @skip_before_libpq(9, 2)
def test_url_is_cool(self): def test_url_is_cool(self):
dsn = ext.make_dsn('postgresql://tester:secret@/test') url = 'postgresql://tester:secret@/test?application_name=wat'
self.assertEqual(dsn, 'postgresql://tester:secret@/test') dsn = ext.make_dsn(url)
self.assertEqual(dsn, url)
dsn = ext.make_dsn('postgresql://tester:secret@/test', dsn = ext.make_dsn(url, application_name='woot')
application_name='woot')
self.assertDsnEqual(dsn, self.assertDsnEqual(dsn,
'dbname=test user=tester password=secret application_name=woot') 'dbname=test user=tester password=secret application_name=woot')
self.assertRaises(psycopg2.ProgrammingError,
ext.make_dsn, 'postgresql://tester:secret@/test?nosuch=param')
self.assertRaises(psycopg2.ProgrammingError,
ext.make_dsn, url, nosuch="param")
class IsolationLevelsTestCase(ConnectingTestCase): class IsolationLevelsTestCase(ConnectingTestCase):

View File

@ -62,8 +62,8 @@ class ConnectTestCase(unittest.TestCase):
self.assertEqual(self.args[2], False) self.assertEqual(self.args[2], False)
def test_dsn(self): def test_dsn(self):
psycopg2.connect('dbname=blah x=y') psycopg2.connect('dbname=blah host=y')
self.assertEqual(self.args[0], 'dbname=blah x=y') self.assertEqual(self.args[0], 'dbname=blah host=y')
self.assertEqual(self.args[1], None) self.assertEqual(self.args[1], None)
self.assertEqual(self.args[2], False) self.assertEqual(self.args[2], False)
@ -88,31 +88,31 @@ class ConnectTestCase(unittest.TestCase):
self.assertEqual(len(self.args[0].split()), 4) self.assertEqual(len(self.args[0].split()), 4)
def test_generic_keywords(self): def test_generic_keywords(self):
psycopg2.connect(foo='bar') psycopg2.connect(options='stuff')
self.assertEqual(self.args[0], 'foo=bar') self.assertEqual(self.args[0], 'options=stuff')
def test_factory(self): def test_factory(self):
def f(dsn, async=False): def f(dsn, async=False):
pass pass
psycopg2.connect(database='foo', bar='baz', connection_factory=f) psycopg2.connect(database='foo', host='baz', connection_factory=f)
self.assertDsnEqual(self.args[0], 'dbname=foo bar=baz') self.assertDsnEqual(self.args[0], 'dbname=foo host=baz')
self.assertEqual(self.args[1], f) self.assertEqual(self.args[1], f)
self.assertEqual(self.args[2], False) self.assertEqual(self.args[2], False)
psycopg2.connect("dbname=foo bar=baz", connection_factory=f) psycopg2.connect("dbname=foo host=baz", connection_factory=f)
self.assertDsnEqual(self.args[0], 'dbname=foo bar=baz') self.assertDsnEqual(self.args[0], 'dbname=foo host=baz')
self.assertEqual(self.args[1], f) self.assertEqual(self.args[1], f)
self.assertEqual(self.args[2], False) self.assertEqual(self.args[2], False)
def test_async(self): def test_async(self):
psycopg2.connect(database='foo', bar='baz', async=1) psycopg2.connect(database='foo', host='baz', async=1)
self.assertDsnEqual(self.args[0], 'dbname=foo bar=baz') self.assertDsnEqual(self.args[0], 'dbname=foo host=baz')
self.assertEqual(self.args[1], None) self.assertEqual(self.args[1], None)
self.assert_(self.args[2]) self.assert_(self.args[2])
psycopg2.connect("dbname=foo bar=baz", async=True) psycopg2.connect("dbname=foo host=baz", async=True)
self.assertDsnEqual(self.args[0], 'dbname=foo bar=baz') self.assertDsnEqual(self.args[0], 'dbname=foo host=baz')
self.assertEqual(self.args[1], None) self.assertEqual(self.args[1], None)
self.assert_(self.args[2]) self.assert_(self.args[2])
@ -143,16 +143,6 @@ class ConnectTestCase(unittest.TestCase):
psycopg2.connect('dbname=foo', user='postgres') psycopg2.connect('dbname=foo', user='postgres')
self.assertDsnEqual(self.args[0], 'dbname=foo user=postgres') self.assertDsnEqual(self.args[0], 'dbname=foo user=postgres')
def test_no_dsn_munging(self):
psycopg2.connect('nosuchparam=whatevs')
self.assertEqual(self.args[0], 'nosuchparam=whatevs')
psycopg2.connect(nosuchparam='whatevs')
self.assertEqual(self.args[0], 'nosuchparam=whatevs')
self.assertRaises(psycopg2.ProgrammingError,
psycopg2.connect, 'nosuchparam=whatevs', andthis='either')
class ExceptionsTestCase(ConnectingTestCase): class ExceptionsTestCase(ConnectingTestCase):
def test_attributes(self): def test_attributes(self):