diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index 159b9de8..b9295bcc 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -78,8 +78,8 @@ psyco_curs_close(cursorObject *self, PyObject *args) static int _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new) { - PyObject *key, *value, *n, *item; - char *d, *c; + PyObject *key, *value, *n; + const char *d, *c; Py_ssize_t index = 0; int force = 0, kind = 0; @@ -90,19 +90,26 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new) c = Bytes_AsString(fmt); while(*c) { - /* handle plain percent symbol in format string */ - if (c[0] == '%' && c[1] == '%') { - c+=2; force = 1; + if (*c++ != '%') { + /* a regular character */ + continue; } + switch (*c) { + + /* handle plain percent symbol in format string */ + case '%': + ++c; + force = 1; + break; + /* if we find '%(' then this is a dictionary, we: 1/ find the matching ')' and extract the key name 2/ locate the value in the dictionary (or return an error) 3/ mogrify the value into something usefull (quoting)... 4/ ...and add it to the new dictionary to be used as argument */ - else if (c[0] == '%' && c[1] == '(') { - + case '(': /* check if some crazy guy mixed formats */ if (kind == 2) { Py_XDECREF(n); @@ -113,10 +120,10 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new) kind = 1; /* let's have d point the end of the argument */ - for (d = c + 2; *d && *d != ')'; d++); + for (d = c + 1; *d && *d != ')' && *d != '%'; d++); if (*d == ')') { - key = Text_FromUTF8AndSize(c+2, (Py_ssize_t) (d-c-2)); + key = Text_FromUTF8AndSize(c+1, (Py_ssize_t) (d-c-1)); value = PyObject_GetItem(var, key); /* key has refcnt 1, value the original value + 1 */ @@ -135,11 +142,9 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new) n = PyDict_New(); } - if ((item = PyObject_GetItem(n, key)) == NULL) { + if (0 == PyDict_Contains(n, key)) { PyObject *t = NULL; - PyErr_Clear(); - /* None is always converted to NULL; this is an optimization over the adapting code and can go away in the future if somebody finds a None adapter useful. */ @@ -148,13 +153,6 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new) t = psyco_null; PyDict_SetItem(n, key, t); /* t is a new object, refcnt = 1, key is at 2 */ - - /* if the value is None we need to substitute the - formatting char with 's' (FIXME: this should not be - necessary if we drop support for formats other than - %s!) */ - while (*d && !isalpha(*d)) d++; - if (*d) *d = 's'; } else { t = microprotocol_getquoted(value, conn); @@ -176,20 +174,21 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new) if it was added to the dictionary directly; good */ Py_XDECREF(value); } - else { - /* we have an item with one extra refcnt here, zap! */ - Py_DECREF(item); - } Py_DECREF(key); /* key has the original refcnt now */ Dprintf("_mogrify: after value refcnt: " - FORMAT_CODE_PY_SSIZE_T, - Py_REFCNT(value) - ); + FORMAT_CODE_PY_SSIZE_T, Py_REFCNT(value)); } - c = d; - } + else { + /* we found %( but not a ) */ + Py_XDECREF(n); + psyco_set_error(ProgrammingError, (PyObject*)conn, + "incomplete placeholder: '%(' without ')'", NULL, NULL); + return -1; + } + c = d + 1; /* after the ) */ + break; - else if (c[0] == '%' && c[1] != '(') { + default: /* this is a format that expects a tuple; it is much easier, because we don't need to check the old/new dictionary for keys */ @@ -218,13 +217,9 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new) } /* let's have d point just after the '%' */ - d = c+1; - if (value == Py_None) { Py_INCREF(psyco_null); PyTuple_SET_ITEM(n, index, psyco_null); - while (*d && !isalpha(*d)) d++; - if (*d) *d = 's'; Py_DECREF(value); } else { @@ -240,12 +235,8 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new) return -1; } } - c = d; index += 1; } - else { - c++; - } } if (force && n == NULL) diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 65050c81..b8a8b662 100755 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -91,6 +91,17 @@ class CursorTests(unittest.TestCase): self.assertEqual(b('SELECT 10.3;'), cur.mogrify("SELECT %s;", (Decimal("10.3"),))) + def test_bad_placeholder(self): + cur = self.conn.cursor() + self.assertRaises(psycopg2.ProgrammingError, + cur.mogrify, "select %(foo", {}) + self.assertRaises(psycopg2.ProgrammingError, + cur.mogrify, "select %(foo", {'foo': 1}) + self.assertRaises(psycopg2.ProgrammingError, + cur.mogrify, "select %(foo, %(bar)", {'foo': 1}) + self.assertRaises(psycopg2.ProgrammingError, + cur.mogrify, "select %(foo, %(bar)", {'foo': 1, 'bar': 2}) + def test_cast(self): curs = self.conn.cursor()