diff --git a/lib/sql.py b/lib/sql.py index ae7c4bab..0bd1e31b 100644 --- a/lib/sql.py +++ b/lib/sql.py @@ -23,6 +23,8 @@ # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # License for more details. +import re +import collections from psycopg2 import extensions as ext @@ -187,8 +189,102 @@ class Placeholder(Composible): return "%s" -def compose(sql, args=()): - raise NotImplementedError +re_compose = re.compile(""" + % # percent sign + (?: + ([%s]) # either % or s + | \( ([^\)]+) \) s # or a (named)s placeholder (named captured) + ) + """, re.VERBOSE) + + +def compose(sql, args=None): + phs = list(re_compose.finditer(sql)) + + # check placeholders consistent + counts = {'%': 0, 's': 0, None: 0} + for ph in phs: + counts[ph.group(1)] += 1 + + npos = counts['s'] + nnamed = counts[None] + + if npos and nnamed: + raise ValueError( + "the sql string contains both named and positional placeholders") + + elif npos: + if not isinstance(args, collections.Sequence): + raise TypeError( + "the sql string expects values in a sequence, got %s instead" + % type(args).__name__) + + if len(args) != npos: + raise ValueError( + "the sql string expects %s values, got %s" % (npos, len(args))) + + return _compose_seq(sql, phs, args) + + elif nnamed: + if not isinstance(args, collections.Mapping): + raise TypeError( + "the sql string expects values in a mapping, got %s instead" + % type(args)) + + return _compose_map(sql, phs, args) + + else: + if not isinstance(args, collections.Sequence) and args: + raise TypeError( + "the sql string expects no value, got %s instead" % len(args)) + # If args are a mapping, no placeholder is an acceptable case + + # Convert %% into % + return _compose_seq(sql, phs, ()) + + +def _compose_seq(sql, phs, args): + rv = [] + j = 0 + for i, ph in enumerate(phs): + if i: + rv.append(SQL(sql[phs[i - 1].end():ph.start()])) + else: + rv.append(SQL(sql[0:ph.start()])) + + if ph.group(1) == 's': + rv.append(args[j]) + j += 1 + else: + rv.append(SQL('%')) + + if phs: + rv.append(SQL(sql[phs[-1].end():])) + else: + rv.append(sql) + + return Composed(rv) + + +def _compose_map(sql, phs, args): + rv = [] + for i, ph in enumerate(phs): + if i: + rv.append(SQL(sql[phs[i - 1].end():ph.start()])) + else: + rv.append(SQL(sql[0:ph.start()])) + + if ph.group(2): + rv.append(args[ph.group(2)]) + else: + rv.append(SQL('%')) + + if phs: + rv.append(SQL(sql[phs[-1].end():])) + else: + rv.append(sql) + + return Composed(rv) # Alias diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index a7303c68..485c69ec 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -267,10 +267,34 @@ _mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new) return 0; } +/* Return 1 if `obj` is a `psycopg2.sql.Composible` instance, else 0 + * Set an exception and return -1 in case of error. + */ +RAISES_NEG static int +_curs_is_composible(PyObject *obj) +{ + int rv = -1; + PyObject *m = NULL; + PyObject *comp = NULL; + + if (!(m = PyImport_ImportModule("psycopg2.sql"))) { goto exit; } + if (!(comp = PyObject_GetAttrString(m, "Composible"))) { goto exit; } + rv = PyObject_IsInstance(obj, comp); + +exit: + Py_XDECREF(comp); + Py_XDECREF(m); + return rv; + +} + static PyObject *_psyco_curs_validate_sql_basic( cursorObject *self, PyObject *sql ) { + PyObject *rv = NULL; + int comp; + /* Performs very basic validation on an incoming SQL string. Returns a new reference to a str instance on success; NULL on failure, after having set an exception. */ @@ -278,26 +302,32 @@ static PyObject *_psyco_curs_validate_sql_basic( if (!sql || !PyObject_IsTrue(sql)) { psyco_set_error(ProgrammingError, self, "can't execute an empty query"); - goto fail; + goto exit; } if (Bytes_Check(sql)) { /* Necessary for ref-count symmetry with the unicode case: */ Py_INCREF(sql); + rv = sql; } else if (PyUnicode_Check(sql)) { - if (!(sql = conn_encode(self->conn, sql))) { goto fail; } + if (!(rv = conn_encode(self->conn, sql))) { goto exit; } + } + else if (0 != (comp = _curs_is_composible(sql))) { + if (comp < 0) { goto exit; } + if (!(rv = PyObject_CallMethod(sql, "as_string", "O", self->conn))) { + goto exit; + } } else { /* the is not unicode or string, raise an error */ PyErr_SetString(PyExc_TypeError, - "argument 1 must be a string or unicode object"); - goto fail; + "argument 1 must be a string or unicode object"); + goto exit; } - return sql; /* new reference */ - fail: - return NULL; +exit: + return rv; } /* Merge together a query string and its arguments.