From 7aba8b3ed0483c675d757bf52c8ce9456c9aeeb1 Mon Sep 17 00:00:00 2001 From: Oleksandr Shulgin Date: Tue, 27 Oct 2015 12:54:10 +0100 Subject: [PATCH] Rework psycopg2.connect() interface. --- doc/src/extensions.rst | 22 +++++ lib/__init__.py | 49 +--------- lib/extensions.py | 3 +- psycopg/psycopg.h | 6 ++ psycopg/psycopgmodule.c | 210 ++++++++++++++++++++++++++++++++++++++-- psycopg/utils.c | 44 +++++++++ tests/test_module.py | 32 ++++-- 7 files changed, 303 insertions(+), 63 deletions(-) diff --git a/doc/src/extensions.rst b/doc/src/extensions.rst index d96cca4f..dcaa2340 100644 --- a/doc/src/extensions.rst +++ b/doc/src/extensions.rst @@ -24,6 +24,28 @@ functionalities defined by the |DBAPI|_. >>> psycopg2.extensions.parse_dsn('dbname=test user=postgres password=secret') {'password': 'secret', 'user': 'postgres', 'dbname': 'test'} +.. function:: make_dsn(**kwargs) + + Wrap keyword parameters into a connection string, applying necessary + quoting and escaping any special characters (namely, single quote and + backslash). + + Example (note the order of parameters in the resulting string is + arbitrary):: + + >>> psycopg2.extensions.make_dsn(dbname='test', user='postgres', password='secret') + 'user=postgres dbname=test password=secret' + + As a special case, the *database* keyword is translated to *dbname*:: + + >>> psycopg2.extensions.make_dsn(database='test') + 'dbname=test' + + An example of quoting (using `print()` for clarity):: + + >>> print(psycopg2.extensions.make_dsn(database='test', password="some\\thing ''special")) + password='some\\thing \'\'special' dbname=test + .. class:: connection(dsn, async=False) Is the class usually returned by the `~psycopg2.connect()` function. diff --git a/lib/__init__.py b/lib/__init__.py index 994b15a8..39dd12e2 100644 --- a/lib/__init__.py +++ b/lib/__init__.py @@ -56,7 +56,7 @@ from psycopg2._psycopg import Error, Warning, DataError, DatabaseError, Programm from psycopg2._psycopg import IntegrityError, InterfaceError, InternalError from psycopg2._psycopg import NotSupportedError, OperationalError -from psycopg2._psycopg import _connect, apilevel, threadsafety, paramstyle +from psycopg2._psycopg import _connect, parse_args, apilevel, threadsafety, paramstyle from psycopg2._psycopg import __version__, __libpq_version__ from psycopg2 import tz @@ -80,27 +80,8 @@ else: _ext.register_adapter(Decimal, Adapter) del Decimal, Adapter -import re - -def _param_escape(s, - re_escape=re.compile(r"([\\'])"), - re_space=re.compile(r'\s')): - """ - Apply the escaping rule required by PQconnectdb - """ - if not s: return "''" - - s = re_escape.sub(r'\\\1', s) - if re_space.search(s): - s = "'" + s + "'" - - return s - -del re - def connect(dsn=None, - database=None, user=None, password=None, host=None, port=None, connection_factory=None, cursor_factory=None, async=False, **kwargs): """ Create a new database connection. @@ -135,33 +116,7 @@ def connect(dsn=None, library: the list of supported parameters depends on the library version. """ - items = [] - if database is not None: - items.append(('dbname', database)) - if user is not None: - items.append(('user', user)) - if password is not None: - items.append(('password', password)) - if host is not None: - items.append(('host', host)) - if port is not None: - items.append(('port', port)) - - items.extend([(k, v) for (k, v) in kwargs.iteritems() if v is not None]) - - if dsn is not None and items: - raise TypeError( - "'%s' is an invalid keyword argument when the dsn is specified" - % items[0][0]) - - if dsn is None: - if not items: - raise TypeError('missing dsn and no parameters') - else: - dsn = " ".join(["%s=%s" % (k, _param_escape(str(v))) - for (k, v) in items]) - - conn = _connect(dsn, connection_factory=connection_factory, async=async) + conn = _connect(dsn, connection_factory, async, **kwargs) if cursor_factory is not None: conn.cursor_factory = cursor_factory diff --git a/lib/extensions.py b/lib/extensions.py index b40e28b8..f99ed939 100644 --- a/lib/extensions.py +++ b/lib/extensions.py @@ -56,7 +56,8 @@ try: except ImportError: pass -from psycopg2._psycopg import adapt, adapters, encodings, connection, cursor, lobject, Xid, libpq_version, parse_dsn, quote_ident +from psycopg2._psycopg import adapt, adapters, encodings, connection, cursor, lobject, Xid, libpq_version +from psycopg2._psycopg import parse_dsn, make_dsn, quote_ident from psycopg2._psycopg import string_types, binary_types, new_type, new_array_type, register_type from psycopg2._psycopg import ISQLQuote, Notify, Diagnostics, Column diff --git a/psycopg/psycopg.h b/psycopg/psycopg.h index eb406fd2..770de7c6 100644 --- a/psycopg/psycopg.h +++ b/psycopg/psycopg.h @@ -119,11 +119,17 @@ typedef struct cursorObject cursorObject; typedef struct connectionObject connectionObject; /* some utility functions */ +HIDDEN PyObject *psyco_parse_args(PyObject *self, PyObject *args, PyObject *kwargs); +HIDDEN PyObject *psyco_parse_dsn(PyObject *self, PyObject *args, PyObject *kwargs); +HIDDEN PyObject *psyco_make_dsn(PyObject *self, PyObject *args, PyObject *kwargs); + RAISES HIDDEN PyObject *psyco_set_error(PyObject *exc, cursorObject *curs, const char *msg); HIDDEN char *psycopg_escape_string(connectionObject *conn, const char *from, Py_ssize_t len, char *to, Py_ssize_t *tolen); HIDDEN char *psycopg_escape_identifier_easy(const char *from, Py_ssize_t len); +HIDDEN char *psycopg_escape_conninfo(const char *from, Py_ssize_t len); + HIDDEN int psycopg_strdup(char **to, const char *from, Py_ssize_t len); HIDDEN int psycopg_is_text_file(PyObject *f); diff --git a/psycopg/psycopgmodule.c b/psycopg/psycopgmodule.c index cf70a4ad..03b115d0 100644 --- a/psycopg/psycopgmodule.c +++ b/psycopg/psycopgmodule.c @@ -70,24 +70,104 @@ HIDDEN PyObject *psyco_null = NULL; /* The type of the cursor.description items */ HIDDEN PyObject *psyco_DescriptionType = NULL; + +/* finds a keyword or positional arg (pops it from kwargs if found there) */ +static PyObject * +parse_arg(int pos, char *name, PyObject *defval, PyObject *args, PyObject *kwargs) +{ + Py_ssize_t nargs = PyTuple_GET_SIZE(args); + PyObject *val = NULL; + + if (kwargs && PyMapping_HasKeyString(kwargs, name)) { + val = PyMapping_GetItemString(kwargs, name); + Py_XINCREF(val); + PyMapping_DelItemString(kwargs, name); /* pop from the kwargs dict! */ + } + if (nargs > pos) { + if (!val) { + val = PyTuple_GET_ITEM(args, pos); + Py_XINCREF(val); + } else { + PyErr_Format(PyExc_TypeError, + "parse_args() got multiple values for keyword argument '%s'", name); + return NULL; + } + } + if (!val) { + val = defval; + Py_XINCREF(val); + } + + return val; +} + + +#define psyco_parse_args_doc \ +"parse_args(...) -- parse connection parameters.\n\n" \ +"Return a tuple of (dsn, connection_factory, async)" + +PyObject * +psyco_parse_args(PyObject *self, PyObject *args, PyObject *kwargs) +{ + Py_ssize_t nargs = PyTuple_GET_SIZE(args); + PyObject *dsn = NULL; + PyObject *factory = NULL; + PyObject *async = NULL; + PyObject *res = NULL; + + if (nargs > 3) { + PyErr_Format(PyExc_TypeError, + "parse_args() takes at most 3 arguments (%d given)", (int)nargs); + goto exit; + } + /* parse and remove all keywords we know, so they are not interpreted as part of DSN */ + if (!(dsn = parse_arg(0, "dsn", Py_None, args, kwargs))) { goto exit; } + if (!(factory = parse_arg(1, "connection_factory", Py_None, + args, kwargs))) { goto exit; } + if (!(async = parse_arg(2, "async", Py_False, args, kwargs))) { goto exit; } + + if (kwargs && PyMapping_Size(kwargs) > 0) { + if (dsn == Py_None) { + Py_DECREF(dsn); + if (!(dsn = psyco_make_dsn(NULL, NULL, kwargs))) { goto exit; } + } else { + PyErr_SetString(PyExc_TypeError, "both dsn and parameters given"); + goto exit; + } + } else { + if (dsn == Py_None) { + PyErr_SetString(PyExc_TypeError, "missing dsn and no parameters"); + goto exit; + } + } + + res = PyTuple_Pack(3, dsn, factory, async); + +exit: + Py_XDECREF(dsn); + Py_XDECREF(factory); + Py_XDECREF(async); + + return res; +} + + /** connect module-level function **/ #define psyco_connect_doc \ -"_connect(dsn, [connection_factory], [async]) -- New database connection.\n\n" +"_connect(dsn, [connection_factory], [async], **kwargs) -- New database connection.\n\n" static PyObject * psyco_connect(PyObject *self, PyObject *args, PyObject *keywds) { PyObject *conn = NULL; + PyObject *tuple = NULL; PyObject *factory = NULL; const char *dsn = NULL; int async = 0; - static char *kwlist[] = {"dsn", "connection_factory", "async", NULL}; + if (!(tuple = psyco_parse_args(self, args, keywds))) { goto exit; } - if (!PyArg_ParseTupleAndKeywords(args, keywds, "s|Oi", kwlist, - &dsn, &factory, &async)) { - return NULL; - } + if (!PyArg_ParseTuple(tuple, "s|Oi", &dsn, &factory, &async)) { goto exit; } Dprintf("psyco_connect: dsn = '%s', async = %d", dsn, async); @@ -109,12 +189,16 @@ psyco_connect(PyObject *self, PyObject *args, PyObject *keywds) conn = PyObject_CallFunction(factory, "si", dsn, async); } +exit: + Py_XDECREF(tuple); + return conn; } + #define psyco_parse_dsn_doc "parse_dsn(dsn) -> dict" -static PyObject * +PyObject * psyco_parse_dsn(PyObject *self, PyObject *args, PyObject *kwargs) { char *err = NULL; @@ -166,6 +250,114 @@ exit: } +#define psyco_make_dsn_doc "make_dsn(**kwargs) -> str" + +PyObject * +psyco_make_dsn(PyObject *self, PyObject *args, PyObject *kwargs) +{ + Py_ssize_t len, pos; + PyObject *res = NULL; + PyObject *key = NULL, *value = NULL; + PyObject *newkey, *newval; + PyObject *dict = NULL; + char *str = NULL, *p, *q; + + if (args && (len = PyTuple_Size(args)) > 0) { + PyErr_Format(PyExc_TypeError, "make_dsn() takes no arguments (%d given)", (int)len); + goto exit; + } + if (kwargs == NULL) { + return Text_FromUTF8(""); + } + + /* iterate through kwargs, calculating the total resulting string + length and saving prepared key/values to a temp. dict */ + if (!(dict = PyDict_New())) { goto exit; } + + len = 0; + pos = 0; + while (PyDict_Next(kwargs, &pos, &key, &value)) { + if (value == NULL || value == Py_None) { continue; } + + Py_INCREF(key); /* for ensure_bytes */ + if (!(newkey = psycopg_ensure_bytes(key))) { goto exit; } + + /* special handling of 'database' keyword */ + if (strcmp(Bytes_AsString(newkey), "database") == 0) { + key = Bytes_FromString("dbname"); + Py_DECREF(newkey); + } else { + key = newkey; + } + + /* now transform the value */ + if (Bytes_CheckExact(value)) { + Py_INCREF(value); + } else if (PyUnicode_CheckExact(value)) { + if (!(value = PyUnicode_AsUTF8String(value))) { goto exit; } + } else { + /* this could be port=5432, so we need to get the text representation */ + if (!(value = PyObject_Str(value))) { goto exit; } + /* and still ensure it's bytes() (but no need to incref here) */ + if (!(value = psycopg_ensure_bytes(value))) { goto exit; } + } + + /* passing NULL for plen checks for NIL bytes in content and errors out */ + if (Bytes_AsStringAndSize(value, &str, NULL) < 0) { goto exit; } + /* escape any special chars */ + if (!(str = psycopg_escape_conninfo(str, 0))) { goto exit; } + if (!(newval = Bytes_FromString(str))) { + goto exit; + } + PyMem_Free(str); + str = NULL; + Py_DECREF(value); + value = newval; + + /* finally put into the temp. dict */ + if (PyDict_SetItem(dict, key, value) < 0) { goto exit; } + + len += Bytes_GET_SIZE(key) + Bytes_GET_SIZE(value) + 2; /* =, space or NIL */ + + Py_DECREF(key); + Py_DECREF(value); + } + key = NULL; + value = NULL; + + if (!(str = PyMem_Malloc(len))) { + PyErr_NoMemory(); + goto exit; + } + + p = str; + pos = 0; + while (PyDict_Next(dict, &pos, &newkey, &newval)) { + if (p != str) { + *(p++) = ' '; + } + if (Bytes_AsStringAndSize(newkey, &q, &len) < 0) { goto exit; } + strncpy(p, q, len); + p += len; + *(p++) = '='; + if (Bytes_AsStringAndSize(newval, &q, &len) < 0) { goto exit; } + strncpy(p, q, len); + p += len; + } + *p = '\0'; + + res = Text_FromUTF8AndSize(str, p - str); + +exit: + PyMem_Free(str); + Py_XDECREF(key); + Py_XDECREF(value); + Py_XDECREF(dict); + + return res; +} + + #define psyco_quote_ident_doc \ "quote_ident(str, conn_or_curs) -> str -- wrapper around PQescapeIdentifier\n\n" \ ":Parameters:\n" \ @@ -820,8 +1012,12 @@ error: static PyMethodDef psycopgMethods[] = { {"_connect", (PyCFunction)psyco_connect, METH_VARARGS|METH_KEYWORDS, psyco_connect_doc}, + {"parse_args", (PyCFunction)psyco_parse_args, + METH_VARARGS|METH_KEYWORDS, psyco_parse_args_doc}, {"parse_dsn", (PyCFunction)psyco_parse_dsn, METH_VARARGS|METH_KEYWORDS, psyco_parse_dsn_doc}, + {"make_dsn", (PyCFunction)psyco_make_dsn, + METH_VARARGS|METH_KEYWORDS, psyco_make_dsn_doc}, {"quote_ident", (PyCFunction)psyco_quote_ident, METH_VARARGS|METH_KEYWORDS, psyco_quote_ident_doc}, {"adapt", (PyCFunction)psyco_microprotocols_adapt, diff --git a/psycopg/utils.c b/psycopg/utils.c index ec8e47c8..e9dc3ba6 100644 --- a/psycopg/utils.c +++ b/psycopg/utils.c @@ -124,6 +124,50 @@ psycopg_escape_identifier_easy(const char *from, Py_ssize_t len) return rv; } +char * +psycopg_escape_conninfo(const char *from, Py_ssize_t len) +{ + char *rv = NULL; + const char *src; + const char *end; + char *dst; + int space = 0; + + if (!len) { len = strlen(from); } + end = from + len; + + if (!(rv = PyMem_Malloc(3 + 2 * len))) { + PyErr_NoMemory(); + return NULL; + } + + /* check for any whitespace or empty string */ + if (from < end && *from) { + for (src = from; src < end && *src; ++src) { + if (isspace(*src)) { + space = 1; + break; + } + } + } else { + /* empty string: we should produce '' */ + space = 1; + } + + dst = rv; + if (space) { *(dst++) = '\''; } + /* scan and copy */ + for (src = from; src < end && *src; ++src, ++dst) { + if (*src == '\'' || *src == '\\') + *(dst++) = '\\'; + *dst = *src; + } + if (space) { *(dst++) = '\''; } + *dst = '\0'; + + return rv; +} + /* Duplicate a string. * * Allocate a new buffer on the Python heap containing the new string. diff --git a/tests/test_module.py b/tests/test_module.py index 62b85ee2..528f79c5 100755 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -34,11 +34,11 @@ import psycopg2 class ConnectTestCase(unittest.TestCase): def setUp(self): self.args = None - def conect_stub(dsn, connection_factory=None, async=False): - self.args = (dsn, connection_factory, async) + def connect_stub(*args, **kwargs): + self.args = psycopg2.parse_args(*args, **kwargs) self._connect_orig = psycopg2._connect - psycopg2._connect = conect_stub + psycopg2._connect = connect_stub def tearDown(self): psycopg2._connect = self._connect_orig @@ -91,29 +91,45 @@ class ConnectTestCase(unittest.TestCase): pass psycopg2.connect(database='foo', bar='baz', connection_factory=f) - self.assertEqual(self.args[0], 'dbname=foo bar=baz') + dsn = " %s " % self.args[0] + self.assertIn(" dbname=foo ", dsn) + self.assertIn(" bar=baz ", dsn) self.assertEqual(self.args[1], f) self.assertEqual(self.args[2], False) psycopg2.connect("dbname=foo bar=baz", connection_factory=f) - self.assertEqual(self.args[0], 'dbname=foo bar=baz') + dsn = " %s " % self.args[0] + self.assertIn(" dbname=foo ", dsn) + self.assertIn(" bar=baz ", dsn) self.assertEqual(self.args[1], f) self.assertEqual(self.args[2], False) def test_async(self): psycopg2.connect(database='foo', bar='baz', async=1) - self.assertEqual(self.args[0], 'dbname=foo bar=baz') + dsn = " %s " % self.args[0] + self.assertIn(" dbname=foo ", dsn) + self.assertIn(" bar=baz ", dsn) self.assertEqual(self.args[1], None) self.assert_(self.args[2]) psycopg2.connect("dbname=foo bar=baz", async=True) - self.assertEqual(self.args[0], 'dbname=foo bar=baz') + dsn = " %s " % self.args[0] + self.assertIn(" dbname=foo ", dsn) + self.assertIn(" bar=baz ", dsn) self.assertEqual(self.args[1], None) self.assert_(self.args[2]) + def test_int_port_param(self): + psycopg2.connect(database='sony', port=6543) + dsn = " %s " % self.args[0] + self.assertIn(" dbname=sony ", dsn) + self.assertIn(" port=6543 ", dsn) + def test_empty_param(self): psycopg2.connect(database='sony', password='') - self.assertEqual(self.args[0], "dbname=sony password=''") + dsn = " %s " % self.args[0] + self.assertIn(" dbname=sony ", dsn) + self.assertIn(" password='' ", dsn) def test_escape(self): psycopg2.connect(database='hello world')