From bc2aefeacface7f8636a8609b4f09ba40eb4f7b1 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Tue, 5 Oct 2010 03:13:44 +0100 Subject: [PATCH] cursor.mogrify() accepts unicode queries. --- ChangeLog | 2 ++ psycopg/cursor_type.c | 82 ++++++++++++++++++++++++------------------- tests/test_cursor.py | 31 ++++++++++++++++ 3 files changed, 78 insertions(+), 37 deletions(-) diff --git a/ChangeLog b/ChangeLog index 7a57b70a..e5afb85f 100644 --- a/ChangeLog +++ b/ChangeLog @@ -2,6 +2,8 @@ * psycopg/cursor_type.c: Common code in execute() and mogrify() merged. + * psycopg/cursor_type.c: cursor.mogrify() accepts unicode queries. + 2010-09-23 Daniele Varrazzo * lib/errorcodes.py: Added PostgreSQL 9.0 error codes. diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index ca9813f9..8bc7c40d 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -567,11 +567,53 @@ psyco_curs_executemany(cursorObject *self, PyObject *args, PyObject *kwargs) #define psyco_curs_mogrify_doc \ "mogrify(query, vars=None) -> str -- Return query after vars binding." +static PyObject * +_psyco_curs_mogrify(cursorObject *self, + PyObject *operation, PyObject *vars) +{ + PyObject *fquery = NULL, *cvt = NULL; + + operation = _psyco_curs_validate_sql_basic(self, operation); + if (operation == NULL) { goto cleanup; } + + Dprintf("psyco_curs_mogrify: starting mogrify"); + + /* here we are, and we have a sequence or a dictionary filled with + objects to be substituted (bound variables). we try to be smart and do + the right thing (i.e., what the user expects) */ + + if (vars && vars != Py_None) + { + if (_mogrify(vars, operation, self->conn, &cvt) == -1) { + goto cleanup; + } + } + + if (vars && cvt) { + if (!(fquery = _psyco_curs_merge_query_args(self, operation, cvt))) { + goto cleanup; + } + + Dprintf("psyco_curs_mogrify: cvt->refcnt = " FORMAT_CODE_PY_SSIZE_T + ", fquery->refcnt = " FORMAT_CODE_PY_SSIZE_T, + cvt->ob_refcnt, fquery->ob_refcnt); + } + else { + fquery = operation; + Py_INCREF(fquery); + } + +cleanup: + Py_XDECREF(operation); + Py_XDECREF(cvt); + + return fquery; +} + static PyObject * psyco_curs_mogrify(cursorObject *self, PyObject *args, PyObject *kwargs) { - PyObject *vars = NULL, *cvt = NULL, *operation = NULL; - PyObject *fquery; + PyObject *vars = NULL, *operation = NULL; static char *kwlist[] = {"query", "vars", NULL}; @@ -580,43 +622,9 @@ psyco_curs_mogrify(cursorObject *self, PyObject *args, PyObject *kwargs) return NULL; } - if (PyUnicode_Check(operation)) { - PyErr_SetString(NotSupportedError, - "unicode queries not yet supported"); - return NULL; - } - EXC_IF_CURS_CLOSED(self); - IFCLEARPGRES(self->pgres); - /* note that we don't overwrite the last query executed on the cursor, we - just *return* the new query with bound variables - - TODO: refactor the common mogrification code (see psycopg_curs_execute - for comments, the code is amost identical) */ - - if (vars) - { - if(_mogrify(vars, operation, self->conn, &cvt) == -1) return NULL; - } - - if (vars && cvt) { - if (!(fquery = _psyco_curs_merge_query_args(self, operation, cvt))) { - return NULL; - } - - Dprintf("psyco_curs_execute: cvt->refcnt = " FORMAT_CODE_PY_SSIZE_T - ", fquery->refcnt = " FORMAT_CODE_PY_SSIZE_T, - cvt->ob_refcnt, fquery->ob_refcnt - ); - Py_DECREF(cvt); - } - else { - fquery = operation; - Py_INCREF(operation); - } - - return fquery; + return _psyco_curs_mogrify(self, operation, vars); } #endif diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 38c368a9..a2c3b247 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -21,6 +21,37 @@ class CursorTests(unittest.TestCase): cur.close() conn.close() + def test_mogrify_unicode(self): + conn = self.connect() + cur = conn.cursor() + + # test consistency between execute and mogrify. + + # unicode query containing only ascii data + cur.execute(u"SELECT 'foo';") + self.assertEqual('foo', cur.fetchone()[0]) + self.assertEqual("SELECT 'foo';", cur.mogrify(u"SELECT 'foo';")) + + conn.set_client_encoding('UTF8') + snowman = u"\u2603" + + # unicode query with non-ascii data + cur.execute(u"SELECT '%s';" % snowman) + self.assertEqual(snowman.encode('utf8'), cur.fetchone()[0]) + self.assertEqual("SELECT '%s';" % snowman.encode('utf8'), + cur.mogrify(u"SELECT '%s';" % snowman).replace("E'", "'")) + + # unicode args + cur.execute("SELECT %s;", (snowman,)) + self.assertEqual(snowman.encode("utf-8"), cur.fetchone()[0]) + self.assertEqual("SELECT '%s';" % snowman.encode('utf8'), + cur.mogrify("SELECT %s;", (snowman,)).replace("E'", "'")) + + # unicode query and args + cur.execute(u"SELECT %s;", (snowman,)) + self.assertEqual(snowman.encode("utf-8"), cur.fetchone()[0]) + self.assertEqual("SELECT '%s';" % snowman.encode('utf8'), + cur.mogrify(u"SELECT %s;", (snowman,)).replace("E'", "'")) def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)