diff --git a/NEWS b/NEWS index 4bb25bc1..519dc866 100644 --- a/NEWS +++ b/NEWS @@ -1,6 +1,8 @@ What's new in psycopg 2.4.3 --------------------------- + - connect() supports all the keyword arguments supported by the + database - Added 'new_array_type()' function for easy creation of array typecasters. - Added support for arrays of hstores and composite types (ticket #66). diff --git a/doc/src/module.rst b/doc/src/module.rst index fb709624..29f4b636 100644 --- a/doc/src/module.rst +++ b/doc/src/module.rst @@ -20,7 +20,7 @@ The module interface respects the standard defined in the |DBAPI|_. Create a new database session and return a new `connection` object. - You can specify the connection parameters either as a string:: + The connection parameters can be specified either as a string:: conn = psycopg2.connect("dbname=test user=postgres password=secret") @@ -28,17 +28,23 @@ The module interface respects the standard defined in the |DBAPI|_. conn = psycopg2.connect(database="test", user="postgres", password="secret") - The full list of available parameters is: - + The basic connection parameters are: + - `!dbname` -- the database name (only in dsn string) - `!database` -- the database name (only as keyword argument) - `!user` -- user name used to authenticate - `!password` -- password used to authenticate - `!host` -- database host address (defaults to UNIX socket if not provided) - `!port` -- connection port number (defaults to 5432 if not provided) - - `!sslmode` -- `SSL TCP/IP negotiation`__ mode - .. __: http://www.postgresql.org/docs/9.0/static/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS + Any other connection parameter supported by the client library/server can + be passed either in the connection string or as keyword. See the + PostgreSQL documentation for a complete `list of supported parameters`__. + Also note that the same parameters can be passed to the client library + using `environment variables`__. + + .. __: http://www.postgresql.org/docs/9.1/static/libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS + .. __: http://www.postgresql.org/docs/9.1/static/libpq-envars.html Using the *connection_factory* parameter a different class or connections factory can be specified. It should be a callable object @@ -48,6 +54,10 @@ The module interface respects the standard defined in the |DBAPI|_. Using *async*\=1 an asynchronous connection will be created: see :ref:`async-support` to know about advantages and limitations. + .. versionchanged:: 2.4.3 + any keyword argument is passed to the connection. Previously only the + basic parameters (plus `!sslmode`) were supported as keywords. + .. extension:: The parameters *connection_factory* and *async* are Psycopg extensions diff --git a/lib/__init__.py b/lib/__init__.py index e04d35b4..0a8ed0f5 100644 --- a/lib/__init__.py +++ b/lib/__init__.py @@ -73,7 +73,7 @@ from psycopg2._psycopg import Error, Warning, DataError, DatabaseError, Programm from psycopg2._psycopg import IntegrityError, InterfaceError, InternalError from psycopg2._psycopg import NotSupportedError, OperationalError -from psycopg2._psycopg import connect, apilevel, threadsafety, paramstyle +from psycopg2._psycopg import _connect, apilevel, threadsafety, paramstyle from psycopg2._psycopg import __version__ from psycopg2 import tz @@ -97,5 +97,87 @@ else: _ext.register_adapter(Decimal, Adapter) del Decimal, Adapter +import re + +def _param_escape(s, + re_escape=re.compile(r"([\\'])"), + re_space=re.compile(r'\s')): + """ + Apply the escaping rule required by PQconnectdb + """ + if not s: return "''" + + s = re_escape.sub(r'\\\1', s) + if re_space.search(s): + s = "'" + s + "'" + + return s + +del re + + +def connect(dsn=None, + database=None, user=None, password=None, host=None, port=None, + connection_factory=None, async=False, **kwargs): + """ + Create a new database connection. + + The connection parameters can be specified either as a string: + + conn = psycopg2.connect("dbname=test user=postgres password=secret") + + or using a set of keyword arguments: + + conn = psycopg2.connect(database="test", user="postgres", password="secret") + + The basic connection parameters are: + + - *dbname*: the database name (only in dsn string) + - *database*: the database name (only as keyword argument) + - *user*: user name used to authenticate + - *password*: password used to authenticate + - *host*: database host address (defaults to UNIX socket if not provided) + - *port*: connection port number (defaults to 5432 if not provided) + + Using the *connection_factory* parameter a different class or connections + factory can be specified. It should be a callable object taking a dsn + argument. + + Using *async*=True an asynchronous connection will be created. + + Any other keyword parameter will be passed to the underlying client + library: the list of supported parameter depends on the library version. + + """ + if dsn is None: + # Note: reproducing the behaviour of the previous C implementation: + # keyword are silently swallowed if a DSN is specified. I would have + # raised an exception. File under "histerical raisins". + items = [] + if database is not None: + items.append(('dbname', database)) + if user is not None: + items.append(('user', user)) + if password is not None: + items.append(('password', password)) + if host is not None: + items.append(('host', host)) + # Reproducing the previous C implementation behaviour: swallow a + # negative port. The libpq would raise an exception for it. + if port is not None and int(port) > 0: + items.append(('port', port)) + + items.extend( + [(k, v) for (k, v) in kwargs.iteritems() if v is not None]) + dsn = " ".join(["%s=%s" % (k, _param_escape(str(v))) + for (k, v) in items]) + + if not dsn: + raise InterfaceError('missing dsn and no parameters') + + return _connect(dsn, + connection_factory=connection_factory, async=async) + + __all__ = filter(lambda k: not k.startswith('_'), locals().keys()) diff --git a/psycopg/psycopgmodule.c b/psycopg/psycopgmodule.c index 2c7e3fbf..3b2b0609 100644 --- a/psycopg/psycopgmodule.c +++ b/psycopg/psycopgmodule.c @@ -75,177 +75,43 @@ HIDDEN PyObject *psyco_DescriptionType = NULL; /** connect module-level function **/ #define psyco_connect_doc \ -"connect(dsn, ...) -- Create a new database connection.\n\n" \ -"This function supports two different but equivalent sets of arguments.\n" \ -"A single data source name or ``dsn`` string can be used to specify the\n" \ -"connection parameters, as follows::\n\n" \ -" psycopg2.connect(\"dbname=xxx user=xxx ...\")\n\n" \ -"If ``dsn`` is not provided it is possible to pass the parameters as\n" \ -"keyword arguments; e.g.::\n\n" \ -" psycopg2.connect(database='xxx', user='xxx', ...)\n\n" \ -"The full list of available parameters is:\n\n" \ -"- ``dbname`` -- database name (only in 'dsn')\n" \ -"- ``database`` -- database name (only as keyword argument)\n" \ -"- ``host`` -- host address (defaults to UNIX socket if not provided)\n" \ -"- ``port`` -- port number (defaults to 5432 if not provided)\n" \ -"- ``user`` -- user name used to authenticate\n" \ -"- ``password`` -- password used to authenticate\n" \ -"- ``sslmode`` -- SSL mode (see PostgreSQL documentation)\n\n" \ -"- ``async`` -- if the connection should provide asynchronous API\n\n" \ -"If the ``connection_factory`` keyword argument is not provided this\n" \ -"function always return an instance of the `connection` class.\n" \ -"Else the given sub-class of `extensions.connection` will be used to\n" \ -"instantiate the connection object.\n\n" \ -":return: New database connection\n" \ -":rtype: `extensions.connection`" - -static size_t -_psyco_connect_fill_dsn(char *dsn, const char *kw, const char *v, size_t i) -{ - strcpy(&dsn[i], kw); i += strlen(kw); - strcpy(&dsn[i], v); i += strlen(v); - return i; -} +"_connect(dsn, [connection_factory], [async]) -- New database connection.\n\n" static PyObject * psyco_connect(PyObject *self, PyObject *args, PyObject *keywds) { - PyObject *conn = NULL, *factory = NULL; - PyObject *pyport = NULL; - - size_t idsn=-1; - int iport=-1; - const char *dsn_static = NULL; - char *dsn_dynamic=NULL; - const char *database=NULL, *user=NULL, *password=NULL; - const char *host=NULL, *sslmode=NULL; - char port[16]; + PyObject *conn = NULL; + PyObject *factory = NULL; + const char *dsn = NULL; int async = 0; - static char *kwlist[] = {"dsn", "database", "host", "port", - "user", "password", "sslmode", - "connection_factory", "async", NULL}; + static char *kwlist[] = {"dsn", "connection_factory", "async", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, keywds, "|sssOsssOi", kwlist, - &dsn_static, &database, &host, &pyport, - &user, &password, &sslmode, - &factory, &async)) { + if (!PyArg_ParseTupleAndKeywords(args, keywds, "s|Oi", kwlist, + &dsn, &factory, &async)) { return NULL; } -#if PY_MAJOR_VERSION < 3 - if (pyport && PyString_Check(pyport)) { - PyObject *pyint = PyInt_FromString(PyString_AsString(pyport), NULL, 10); - if (!pyint) goto fail; - /* Must use PyInt_AsLong rather than PyInt_AS_LONG, because - * PyInt_FromString can return a PyLongObject: */ - iport = PyInt_AsLong(pyint); - Py_DECREF(pyint); - if (iport == -1 && PyErr_Occurred()) - goto fail; - } - else if (pyport && PyInt_Check(pyport)) { - iport = PyInt_AsLong(pyport); - if (iport == -1 && PyErr_Occurred()) - goto fail; - } -#else - if (pyport && PyUnicode_Check(pyport)) { - PyObject *pyint = PyObject_CallFunction((PyObject*)&PyLong_Type, - "Oi", pyport, 10); - if (!pyint) goto fail; - iport = PyLong_AsLong(pyint); - Py_DECREF(pyint); - if (iport == -1 && PyErr_Occurred()) - goto fail; - } - else if (pyport && PyLong_Check(pyport)) { - iport = PyLong_AsLong(pyport); - if (iport == -1 && PyErr_Occurred()) - goto fail; - } -#endif - else if (pyport != NULL) { - PyErr_SetString(PyExc_TypeError, "port must be a string or int"); - goto fail; + Dprintf("psyco_connect: dsn = '%s', async = %d", dsn, async); + + /* allocate connection, fill with errors and return it */ + if (factory == NULL || factory == Py_None) { + factory = (PyObject *)&connectionType; } - if (iport > 0) - PyOS_snprintf(port, 16, "%d", iport); - - if (dsn_static == NULL) { - size_t l = 46; /* len(" dbname= user= password= host= port= sslmode=\0") */ - - if (database) l += strlen(database); - if (host) l += strlen(host); - if (iport > 0) l += strlen(port); - if (user) l += strlen(user); - if (password) l += strlen(password); - if (sslmode) l += strlen(sslmode); - - dsn_dynamic = malloc(l*sizeof(char)); - if (dsn_dynamic == NULL) { - PyErr_SetString(InterfaceError, "dynamic dsn allocation failed"); - goto fail; - } - - idsn = 0; - if (database) - idsn = _psyco_connect_fill_dsn(dsn_dynamic, " dbname=", database, idsn); - if (host) - idsn = _psyco_connect_fill_dsn(dsn_dynamic, " host=", host, idsn); - if (iport > 0) - idsn = _psyco_connect_fill_dsn(dsn_dynamic, " port=", port, idsn); - if (user) - idsn = _psyco_connect_fill_dsn(dsn_dynamic, " user=", user, idsn); - if (password) - idsn = _psyco_connect_fill_dsn(dsn_dynamic, " password=", password, idsn); - if (sslmode) - idsn = _psyco_connect_fill_dsn(dsn_dynamic, " sslmode=", sslmode, idsn); - - if (idsn > 0) { - dsn_dynamic[idsn] = '\0'; - memmove(dsn_dynamic, &dsn_dynamic[1], idsn); - } - else { - PyErr_SetString(InterfaceError, "missing dsn and no parameters"); - goto fail; - } - } - - { - const char *dsn = (dsn_static != NULL ? dsn_static : dsn_dynamic); - Dprintf("psyco_connect: dsn = '%s', async = %d", dsn, async); - - /* allocate connection, fill with errors and return it */ - if (factory == NULL) factory = (PyObject *)&connectionType; - /* Here we are breaking the connection.__init__ interface defined - * by psycopg2. So, if not requiring an async conn, avoid passing - * the async parameter. */ - /* TODO: would it be possible to avoid an additional parameter - * to the conn constructor? A subclass? (but it would require mixins - * to further subclass) Another dsn parameter (but is not really - * a connection parameter that can be configured) */ - if (!async) { + /* Here we are breaking the connection.__init__ interface defined + * by psycopg2. So, if not requiring an async conn, avoid passing + * the async parameter. */ + /* TODO: would it be possible to avoid an additional parameter + * to the conn constructor? A subclass? (but it would require mixins + * to further subclass) Another dsn parameter (but is not really + * a connection parameter that can be configured) */ + if (!async) { conn = PyObject_CallFunction(factory, "s", dsn); - } else { + } else { conn = PyObject_CallFunction(factory, "si", dsn, async); - } } - goto cleanup; - fail: - assert (PyErr_Occurred()); - if (conn != NULL) { - Py_DECREF(conn); - conn = NULL; - } - /* Fall through to cleanup: */ - cleanup: - if (dsn_dynamic != NULL) { - free(dsn_dynamic); - } - return conn; } @@ -754,7 +620,7 @@ exit: /** method table and module initialization **/ static PyMethodDef psycopgMethods[] = { - {"connect", (PyCFunction)psyco_connect, + {"_connect", (PyCFunction)psyco_connect, METH_VARARGS|METH_KEYWORDS, psyco_connect_doc}, {"adapt", (PyCFunction)psyco_microprotocols_adapt, METH_VARARGS, psyco_microprotocols_adapt_doc}, diff --git a/tests/__init__.py b/tests/__init__.py index 1114a950..479f374c 100755 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -33,6 +33,7 @@ import test_extras_dictcursor import test_dates import test_psycopg2_dbapi20 import test_quote +import test_module import test_connection import test_cursor import test_transaction @@ -71,6 +72,7 @@ def test_suite(): suite.addTest(test_types_extras.test_suite()) suite.addTest(test_lobject.test_suite()) suite.addTest(test_copy.test_suite()) + suite.addTest(test_module.test_suite()) suite.addTest(test_notify.test_suite()) suite.addTest(test_async.test_suite()) suite.addTest(test_green.test_suite()) diff --git a/tests/test_module.py b/tests/test_module.py new file mode 100755 index 00000000..9c130f3f --- /dev/null +++ b/tests/test_module.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python + +# test_module.py - unit test for the module interface +# +# Copyright (C) 2011 Daniele Varrazzo +# +# psycopg2 is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# In addition, as a special exception, the copyright holders give +# permission to link this program with the OpenSSL library (or with +# modified versions of OpenSSL that use the same license as OpenSSL), +# and distribute linked combinations including the two. +# +# You must obey the GNU Lesser General Public License in all respects for +# all of the code used other than OpenSSL. +# +# psycopg2 is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +from testutils import unittest + +import psycopg2 + +class ConnectTestCase(unittest.TestCase): + def setUp(self): + self.args = None + def conect_stub(dsn, connection_factory=None, async=False): + self.args = (dsn, connection_factory, async) + + self._connect_orig = psycopg2._connect + psycopg2._connect = conect_stub + + def tearDown(self): + psycopg2._connect = self._connect_orig + + def test_there_has_to_be_something(self): + self.assertRaises(psycopg2.InterfaceError, psycopg2.connect) + self.assertRaises(psycopg2.InterfaceError, psycopg2.connect, + connection_factory=lambda dsn, async=False: None) + self.assertRaises(psycopg2.InterfaceError, psycopg2.connect, + async=True) + + def test_no_keywords(self): + psycopg2.connect('') + self.assertEqual(self.args[0], '') + self.assertEqual(self.args[1], None) + self.assertEqual(self.args[2], False) + + def test_dsn(self): + psycopg2.connect('dbname=blah x=y') + self.assertEqual(self.args[0], 'dbname=blah x=y') + self.assertEqual(self.args[1], None) + self.assertEqual(self.args[2], False) + + def test_supported_keywords(self): + psycopg2.connect(database='foo') + self.assertEqual(self.args[0], 'dbname=foo') + psycopg2.connect(user='postgres') + self.assertEqual(self.args[0], 'user=postgres') + psycopg2.connect(password='secret') + self.assertEqual(self.args[0], 'password=secret') + psycopg2.connect(port=5432) + self.assertEqual(self.args[0], 'port=5432') + psycopg2.connect(sslmode='require') + self.assertEqual(self.args[0], 'sslmode=require') + + psycopg2.connect(database='foo', + user='postgres', password='secret', port=5432) + self.assert_('dbname=foo' in self.args[0]) + self.assert_('user=postgres' in self.args[0]) + self.assert_('password=secret' in self.args[0]) + self.assert_('port=5432' in self.args[0]) + self.assertEqual(len(self.args[0].split()), 4) + + def test_generic_keywords(self): + psycopg2.connect(foo='bar') + self.assertEqual(self.args[0], 'foo=bar') + + def test_factory(self): + def f(dsn, async=False): + pass + + psycopg2.connect(database='foo', bar='baz', connection_factory=f) + self.assertEqual(self.args[0], 'dbname=foo bar=baz') + self.assertEqual(self.args[1], f) + self.assertEqual(self.args[2], False) + + psycopg2.connect("dbname=foo bar=baz", connection_factory=f) + self.assertEqual(self.args[0], 'dbname=foo bar=baz') + self.assertEqual(self.args[1], f) + self.assertEqual(self.args[2], False) + + def test_async(self): + psycopg2.connect(database='foo', bar='baz', async=1) + self.assertEqual(self.args[0], 'dbname=foo bar=baz') + self.assertEqual(self.args[1], None) + self.assert_(self.args[2]) + + psycopg2.connect("dbname=foo bar=baz", async=True) + self.assertEqual(self.args[0], 'dbname=foo bar=baz') + self.assertEqual(self.args[1], None) + self.assert_(self.args[2]) + + def test_empty_param(self): + psycopg2.connect(database='sony', password='') + self.assertEqual(self.args[0], "dbname=sony password=''") + + def test_escape(self): + psycopg2.connect(database='hello world') + self.assertEqual(self.args[0], "dbname='hello world'") + + psycopg2.connect(database=r'back\slash') + self.assertEqual(self.args[0], r"dbname=back\\slash") + + psycopg2.connect(database="quo'te") + self.assertEqual(self.args[0], r"dbname=quo\'te") + + psycopg2.connect(database="with\ttab") + self.assertEqual(self.args[0], "dbname='with\ttab'") + + psycopg2.connect(database=r"\every thing'") + self.assertEqual(self.args[0], r"dbname='\\every thing\''") + + +def test_suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + +if __name__ == "__main__": + unittest.main()