Some cleanup in mogrify

- Raise an exception on incomplete placeholders.
- Minor speedups.
- Don't change the format string in place (??!!) if the placeholder is not
  %s and the value is None (SQL NULL).

The last point is possible because downstream we don't accept any format
other than %s anyway (in the Bytes_Format function).

Note that the format string is now constant regardless of the arguments.
This means that executemany is still less efficient than it could be,
since mogrify could work only on the parameters instead of re-parsing the
format string each time. However, such an implementation is only
worthwhile if we start supporting real parameters.

Let's talk about that for the next release.
This commit is contained in:
Daniele Varrazzo 2011-02-17 23:18:05 +00:00
parent b6d6fbbe8c
commit 99b3c72312
2 changed files with 39 additions and 37 deletions

View File

@ -78,8 +78,8 @@ psyco_curs_close(cursorObject *self, PyObject *args)
static int static int
_mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new) _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new)
{ {
PyObject *key, *value, *n, *item; PyObject *key, *value, *n;
char *d, *c; const char *d, *c;
Py_ssize_t index = 0; Py_ssize_t index = 0;
int force = 0, kind = 0; int force = 0, kind = 0;
@ -90,19 +90,26 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new)
c = Bytes_AsString(fmt); c = Bytes_AsString(fmt);
while(*c) { while(*c) {
/* handle plain percent symbol in format string */ if (*c++ != '%') {
if (c[0] == '%' && c[1] == '%') { /* a regular character */
c+=2; force = 1; continue;
} }
switch (*c) {
/* handle plain percent symbol in format string */
case '%':
++c;
force = 1;
break;
/* if we find '%(' then this is a dictionary, we: /* if we find '%(' then this is a dictionary, we:
1/ find the matching ')' and extract the key name 1/ find the matching ')' and extract the key name
2/ locate the value in the dictionary (or return an error) 2/ locate the value in the dictionary (or return an error)
3/ mogrify the value into something usefull (quoting)... 3/ mogrify the value into something usefull (quoting)...
4/ ...and add it to the new dictionary to be used as argument 4/ ...and add it to the new dictionary to be used as argument
*/ */
else if (c[0] == '%' && c[1] == '(') { case '(':
/* check if some crazy guy mixed formats */ /* check if some crazy guy mixed formats */
if (kind == 2) { if (kind == 2) {
Py_XDECREF(n); Py_XDECREF(n);
@ -113,10 +120,10 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new)
kind = 1; kind = 1;
/* let's have d point the end of the argument */ /* let's have d point the end of the argument */
for (d = c + 2; *d && *d != ')'; d++); for (d = c + 1; *d && *d != ')' && *d != '%'; d++);
if (*d == ')') { if (*d == ')') {
key = Text_FromUTF8AndSize(c+2, (Py_ssize_t) (d-c-2)); key = Text_FromUTF8AndSize(c+1, (Py_ssize_t) (d-c-1));
value = PyObject_GetItem(var, key); value = PyObject_GetItem(var, key);
/* key has refcnt 1, value the original value + 1 */ /* key has refcnt 1, value the original value + 1 */
@ -135,11 +142,9 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new)
n = PyDict_New(); n = PyDict_New();
} }
if ((item = PyObject_GetItem(n, key)) == NULL) { if (0 == PyDict_Contains(n, key)) {
PyObject *t = NULL; PyObject *t = NULL;
PyErr_Clear();
/* None is always converted to NULL; this is an /* None is always converted to NULL; this is an
optimization over the adapting code and can go away in optimization over the adapting code and can go away in
the future if somebody finds a None adapter useful. */ the future if somebody finds a None adapter useful. */
@ -148,13 +153,6 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new)
t = psyco_null; t = psyco_null;
PyDict_SetItem(n, key, t); PyDict_SetItem(n, key, t);
/* t is a new object, refcnt = 1, key is at 2 */ /* t is a new object, refcnt = 1, key is at 2 */
/* if the value is None we need to substitute the
formatting char with 's' (FIXME: this should not be
necessary if we drop support for formats other than
%s!) */
while (*d && !isalpha(*d)) d++;
if (*d) *d = 's';
} }
else { else {
t = microprotocol_getquoted(value, conn); t = microprotocol_getquoted(value, conn);
@ -176,20 +174,21 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new)
if it was added to the dictionary directly; good */ if it was added to the dictionary directly; good */
Py_XDECREF(value); Py_XDECREF(value);
} }
else {
/* we have an item with one extra refcnt here, zap! */
Py_DECREF(item);
}
Py_DECREF(key); /* key has the original refcnt now */ Py_DECREF(key); /* key has the original refcnt now */
Dprintf("_mogrify: after value refcnt: " Dprintf("_mogrify: after value refcnt: "
FORMAT_CODE_PY_SSIZE_T, FORMAT_CODE_PY_SSIZE_T, Py_REFCNT(value));
Py_REFCNT(value)
);
} }
c = d; else {
} /* we found %( but not a ) */
Py_XDECREF(n);
psyco_set_error(ProgrammingError, (PyObject*)conn,
"incomplete placeholder: '%(' without ')'", NULL, NULL);
return -1;
}
c = d + 1; /* after the ) */
break;
else if (c[0] == '%' && c[1] != '(') { default:
/* this is a format that expects a tuple; it is much easier, /* this is a format that expects a tuple; it is much easier,
because we don't need to check the old/new dictionary for because we don't need to check the old/new dictionary for
keys */ keys */
@ -218,13 +217,9 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new)
} }
/* let's have d point just after the '%' */ /* let's have d point just after the '%' */
d = c+1;
if (value == Py_None) { if (value == Py_None) {
Py_INCREF(psyco_null); Py_INCREF(psyco_null);
PyTuple_SET_ITEM(n, index, psyco_null); PyTuple_SET_ITEM(n, index, psyco_null);
while (*d && !isalpha(*d)) d++;
if (*d) *d = 's';
Py_DECREF(value); Py_DECREF(value);
} }
else { else {
@ -240,12 +235,8 @@ _mogrify(PyObject *var, PyObject *fmt, connectionObject *conn, PyObject **new)
return -1; return -1;
} }
} }
c = d;
index += 1; index += 1;
} }
else {
c++;
}
} }
if (force && n == NULL) if (force && n == NULL)

View File

@ -91,6 +91,17 @@ class CursorTests(unittest.TestCase):
self.assertEqual(b('SELECT 10.3;'), self.assertEqual(b('SELECT 10.3;'),
cur.mogrify("SELECT %s;", (Decimal("10.3"),))) cur.mogrify("SELECT %s;", (Decimal("10.3"),)))
# mogrify must raise ProgrammingError when a named placeholder is opened
# with '%(' but never closed with ')', whether or not the named key is
# present in the arguments mapping.
def test_bad_placeholder(self):
cur = self.conn.cursor()
# unterminated placeholder, key missing from the arguments
self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo", {})
# unterminated placeholder, key present in the arguments
self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo", {'foo': 1})
# unterminated placeholder followed by a well-formed one
self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo, %(bar)", {'foo': 1})
# same query, with both keys present in the arguments
self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo, %(bar)", {'foo': 1, 'bar': 2})
def test_cast(self): def test_cast(self):
curs = self.conn.cursor() curs = self.conn.cursor()