diff --git a/doc/src/connection.rst b/doc/src/connection.rst index cceef1e5..2c608909 100644 --- a/doc/src/connection.rst +++ b/doc/src/connection.rst @@ -582,6 +582,16 @@ The ``connection`` class .. __: http://www.postgresql.org/docs/current/static/libpq-status.html#LIBPQ-PQTRANSACTIONSTATUS + .. method:: quote_ident(str) + + Return quoted identifier according to PostgreSQL quoting rules. + + Requires libpq >= 9.0. + + .. seealso:: libpq docs for `PQescapeIdentifier()`__. + + .. __: http://www.postgresql.org/docs/current/static/libpq-exec.html#LIBPQ-PQESCAPEIDENTIFIER + .. index:: pair: Protocol; Version diff --git a/psycopg/connection_type.c b/psycopg/connection_type.c index 2c1dddf2..9ac91447 100644 --- a/psycopg/connection_type.c +++ b/psycopg/connection_type.c @@ -733,6 +733,38 @@ psyco_conn_get_parameter_status(connectionObject *self, PyObject *args) return conn_text_from_chars(self, val); } +/* quote_ident - Quote identifier */ + +#define psyco_conn_quote_ident_doc \ +"quote_ident(str) -- Quote identifier according to PostgreSQL quoting rules.\n\n" \ +"Requires libpq >= 9.0, which provides PQescapeIdentifier()." + +static PyObject * +psyco_conn_quote_ident(connectionObject *self, PyObject *args) +{ +#if PG_VERSION_NUM >= 90000 + const char *str = NULL; + char *quoted; + PyObject *result; + + EXC_IF_CONN_CLOSED(self); + + if (!PyArg_ParseTuple(args, "s", &str)) return NULL; + + quoted = PQescapeIdentifier(self->pgconn, str, strlen(str)); + if (!quoted) { + PyErr_NoMemory(); + return NULL; + } + result = conn_text_from_chars(self, quoted); + PQfreemem(quoted); + + return result; +#else + PyErr_SetString(NotSupportedError, "PQescapeIdentifier not supported by libpq < 9.0"); +#endif +} + /* lobject method - allocate a new lobject */ @@ -991,6 +1023,8 @@ static struct PyMethodDef connectionObject_methods[] = { METH_NOARGS, psyco_conn_isexecuting_doc}, {"cancel", (PyCFunction)psyco_conn_cancel, METH_NOARGS, psyco_conn_cancel_doc}, + {"quote_ident", (PyCFunction)psyco_conn_quote_ident, + METH_VARARGS, psyco_conn_quote_ident_doc}, {NULL} }; diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index cd8d5ca3..5414656b 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -1815,7 +1815,7 @@ cursor_setup(cursorObject *self, connectionObject *conn, const char *name) Dprintf("cursor_setup: parameters: name = %s, conn = %p", name, conn); if (name) { - if (!(self->name = psycopg_escape_identifier_easy(name, 0))) { + if (!(self->name = psycopg_escape_identifier(conn, name, 0))) { return -1; } } diff --git a/psycopg/psycopg.h b/psycopg/psycopg.h index eb406fd2..461762ad 100644 --- a/psycopg/psycopg.h +++ b/psycopg/psycopg.h @@ -124,6 +124,7 @@ RAISES HIDDEN PyObject *psyco_set_error(PyObject *exc, cursorObject *curs, const HIDDEN char *psycopg_escape_string(connectionObject *conn, const char *from, Py_ssize_t len, char *to, Py_ssize_t *tolen); HIDDEN char *psycopg_escape_identifier_easy(const char *from, Py_ssize_t len); +HIDDEN char *psycopg_escape_identifier(connectionObject *conn, const char *from, Py_ssize_t len); HIDDEN int psycopg_strdup(char **to, const char *from, Py_ssize_t len); HIDDEN int psycopg_is_text_file(PyObject *f); diff --git a/psycopg/utils.c b/psycopg/utils.c index 836f6129..b57e3357 100644 --- a/psycopg/utils.c +++ b/psycopg/utils.c @@ -87,7 +87,7 @@ psycopg_escape_string(connectionObject *conn, const char *from, Py_ssize_t len, return to; } -/* Escape a string to build a valid PostgreSQL identifier +/* Escape a string to build a valid PostgreSQL identifier. * * Allocate a new buffer on the Python heap containing the new string. * 'len' is optional: if 0 the length is calculated. @@ -96,7 +96,10 @@ psycopg_escape_string(connectionObject *conn, const char *from, Py_ssize_t len, * * WARNING: this function is not so safe to allow untrusted input: it does no * check for multibyte chars. Such a function should be built on - * PQescapeIndentifier, which is only available from PostgreSQL 9.0. + * PQescapeIdentifier, which is only available from PostgreSQL 9.0. + * + * See below for psycopg_escape_identifier (which requires a connection + * object). */ char * psycopg_escape_identifier_easy(const char *from, Py_ssize_t len) @@ -124,6 +127,54 @@ psycopg_escape_identifier_easy(const char *from, Py_ssize_t len) return rv; } +/* Escape a string to build a valid PostgreSQL identifier. + * + * Allocate a new buffer on the Python heap containing the new string. + * 'len' is optional: if 0 the length is calculated. + * + * The returned string doesn't include quotes. + * + * Uses PQescapeIdentifier internally, if available. + */ +#if PG_VERSION_NUM >= 90000 +char * +psycopg_escape_identifier(connectionObject *conn, const char *from, Py_ssize_t len) +{ + char *rv; + char *str; + Py_ssize_t res_len; + + if (!len) { len = strlen(from); } + str = PQescapeIdentifier(conn->pgconn, from, len); + if (!str) { + PyErr_NoMemory(); + return NULL; + } + + res_len = strlen(str); + /* allocate enough mem, sans the quotes */ + if (!(rv = PyMem_New(char, 1 + res_len - 2))) { + free(str); + PyErr_NoMemory(); + return NULL; + } + /* de-quote */ + strncpy(rv, str + 1, res_len - 2); + rv[res_len - 2] = 0; + + PQfreemem(str); + + return rv; +} +#else +char * +psycopg_escape_identifier(connectionObject *conn, const char *from, Py_ssize_t len) +{ + (void) conn; + return psycopg_escape_identifier_easy(from, len); +} +#endif + /* Duplicate a string. * * Allocate a new buffer on the Python heap containing the new string. diff --git a/tests/test_quote.py b/tests/test_quote.py index e7b3c316..8da2b254 100755 --- a/tests/test_quote.py +++ b/tests/test_quote.py @@ -23,7 +23,7 @@ # License for more details. import sys -from testutils import unittest, ConnectingTestCase +from testutils import unittest, ConnectingTestCase, skip_before_libpq import psycopg2 import psycopg2.extensions @@ -165,6 +165,13 @@ class TestQuotedString(ConnectingTestCase): self.assertEqual(q.encoding, 'utf_8') +class TestQuotedIdentifier(ConnectingTestCase): + @skip_before_libpq(9, 0) + def test_identifier(self): + self.assertEqual(self.conn.quote_ident('blah-blah'), '"blah-blah"') + self.assertEqual(self.conn.quote_ident('quote"inside'), '"quote""inside"') + + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)