Added sql.compose() implementation

This commit is contained in:
Daniele Varrazzo 2017-01-01 05:23:42 +01:00
parent f11e6d82b0
commit c4a67fc1c1
2 changed files with 135 additions and 9 deletions

View File

@ -23,6 +23,8 @@
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details. # License for more details.
import re
import collections
from psycopg2 import extensions as ext from psycopg2 import extensions as ext
@ -187,8 +189,102 @@ class Placeholder(Composible):
return "%s" return "%s"
def compose(sql, args=()): re_compose = re.compile("""
raise NotImplementedError % # percent sign
(?:
([%s]) # either % or s
| \( ([^\)]+) \) s # or a (named)s placeholder (named captured)
)
""", re.VERBOSE)
def compose(sql, args=None):
phs = list(re_compose.finditer(sql))
# check placeholders consistent
counts = {'%': 0, 's': 0, None: 0}
for ph in phs:
counts[ph.group(1)] += 1
npos = counts['s']
nnamed = counts[None]
if npos and nnamed:
raise ValueError(
"the sql string contains both named and positional placeholders")
elif npos:
if not isinstance(args, collections.Sequence):
raise TypeError(
"the sql string expects values in a sequence, got %s instead"
% type(args).__name__)
if len(args) != npos:
raise ValueError(
"the sql string expects %s values, got %s" % (npos, len(args)))
return _compose_seq(sql, phs, args)
elif nnamed:
if not isinstance(args, collections.Mapping):
raise TypeError(
"the sql string expects values in a mapping, got %s instead"
% type(args))
return _compose_map(sql, phs, args)
else:
if not isinstance(args, collections.Sequence) and args:
raise TypeError(
"the sql string expects no value, got %s instead" % len(args))
# If args are a mapping, no placeholder is an acceptable case
# Convert %% into %
return _compose_seq(sql, phs, ())
def _compose_seq(sql, phs, args):
rv = []
j = 0
for i, ph in enumerate(phs):
if i:
rv.append(SQL(sql[phs[i - 1].end():ph.start()]))
else:
rv.append(SQL(sql[0:ph.start()]))
if ph.group(1) == 's':
rv.append(args[j])
j += 1
else:
rv.append(SQL('%'))
if phs:
rv.append(SQL(sql[phs[-1].end():]))
else:
rv.append(sql)
return Composed(rv)
def _compose_map(sql, phs, args):
rv = []
for i, ph in enumerate(phs):
if i:
rv.append(SQL(sql[phs[i - 1].end():ph.start()]))
else:
rv.append(SQL(sql[0:ph.start()]))
if ph.group(2):
rv.append(args[ph.group(2)])
else:
rv.append(SQL('%'))
if phs:
rv.append(SQL(sql[phs[-1].end():]))
else:
rv.append(sql)
return Composed(rv)
# Alias # Alias

View File

@ -267,10 +267,34 @@ _mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new)
return 0; return 0;
} }
/* Return 1 if `obj` is a `psycopg2.sql.Composible` instance, else 0
* Set an exception and return -1 in case of error.
*/
RAISES_NEG static int
_curs_is_composible(PyObject *obj)
{
int rv = -1;
PyObject *m = NULL;
PyObject *comp = NULL;
if (!(m = PyImport_ImportModule("psycopg2.sql"))) { goto exit; }
if (!(comp = PyObject_GetAttrString(m, "Composible"))) { goto exit; }
rv = PyObject_IsInstance(obj, comp);
exit:
Py_XDECREF(comp);
Py_XDECREF(m);
return rv;
}
static PyObject *_psyco_curs_validate_sql_basic( static PyObject *_psyco_curs_validate_sql_basic(
cursorObject *self, PyObject *sql cursorObject *self, PyObject *sql
) )
{ {
PyObject *rv = NULL;
int comp;
/* Performs very basic validation on an incoming SQL string. /* Performs very basic validation on an incoming SQL string.
Returns a new reference to a str instance on success; NULL on failure, Returns a new reference to a str instance on success; NULL on failure,
after having set an exception. */ after having set an exception. */
@ -278,26 +302,32 @@ static PyObject *_psyco_curs_validate_sql_basic(
if (!sql || !PyObject_IsTrue(sql)) { if (!sql || !PyObject_IsTrue(sql)) {
psyco_set_error(ProgrammingError, self, psyco_set_error(ProgrammingError, self,
"can't execute an empty query"); "can't execute an empty query");
goto fail; goto exit;
} }
if (Bytes_Check(sql)) { if (Bytes_Check(sql)) {
/* Necessary for ref-count symmetry with the unicode case: */ /* Necessary for ref-count symmetry with the unicode case: */
Py_INCREF(sql); Py_INCREF(sql);
rv = sql;
} }
else if (PyUnicode_Check(sql)) { else if (PyUnicode_Check(sql)) {
if (!(sql = conn_encode(self->conn, sql))) { goto fail; } if (!(rv = conn_encode(self->conn, sql))) { goto exit; }
}
else if (0 != (comp = _curs_is_composible(sql))) {
if (comp < 0) { goto exit; }
if (!(rv = PyObject_CallMethod(sql, "as_string", "O", self->conn))) {
goto exit;
}
} }
else { else {
/* the is not unicode or string, raise an error */ /* the is not unicode or string, raise an error */
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"argument 1 must be a string or unicode object"); "argument 1 must be a string or unicode object");
goto fail; goto exit;
} }
return sql; /* new reference */ exit:
fail: return rv;
return NULL;
} }
/* Merge together a query string and its arguments. /* Merge together a query string and its arguments.