Use {} instead of %s placeholders in SQL composition

This commit is contained in:
Daniele Varrazzo 2017-01-03 15:02:34 +01:00
parent 49461c2c39
commit a76e665567
3 changed files with 141 additions and 162 deletions

View File

@ -51,7 +51,8 @@ from the query parameters::
from psycopg2 import sql from psycopg2 import sql
cur.execute( cur.execute(
sql.SQL("insert into %s values (%%s, %%s)") % [sql.Identifier('my_table')], sql.SQL("insert into {} values (%s, %s)")
.format(sql.Identifier('my_table')),
[10, 20]) [10, 20])

View File

@ -23,24 +23,27 @@
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details. # License for more details.
import re
import sys import sys
import collections import string
from psycopg2 import extensions as ext from psycopg2 import extensions as ext
_formatter = string.Formatter()
class Composable(object): class Composable(object):
""" """
Abstract base class for objects that can be used to compose an SQL string. Abstract base class for objects that can be used to compose an SQL string.
Composables can be passed directly to `~cursor.execute()` and `!Composable` objects can be passed directly to `~cursor.execute()` and
`~cursor.executemany()`. `~cursor.executemany()`.
Composables can be joined using the ``+`` operator: the result will be `!Composable` objects can be joined using the ``+`` operator: the result
a `Composed` instance containing the objects joined. The operator ``*`` is will be a `Composed` instance containing the objects joined. The operator
also supported with an integer argument: the result is a `!Composed` ``*`` is also supported with an integer argument: the result is a
instance containing the left argument repeated as many times as requested. `!Composed` instance containing the left argument repeated as many times as
requested.
.. automethod:: as_string .. automethod:: as_string
""" """
@ -144,21 +147,22 @@ class Composed(Composable):
class SQL(Composable): class SQL(Composable):
""" """
A `Composable` representing a snippet of SQL string to be included verbatim. A `Composable` representing a snippet of SQL statement.
`!SQL` supports the ``%`` operator to incorporate variable parts of a query `!SQL` exposes `join()` and `format()` methods useful to create a template
into a template: the operator takes a sequence or mapping of `Composable` where to merge variable parts of a query (for instance field or table
(according to the style of the placeholders in the *string*) and returning names).
a `Composed` object.
Example:: Example::
>>> query = sql.SQL("select %s from %s") % [ >>> query = sql.SQL("select {} from {}").format(
... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]), ... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
... sql.Identifier('table')] ... sql.Identifier('table'))
>>> print(query.as_string(conn)) >>> print(query.as_string(conn))
select "foo", "bar" from "table" select "foo", "bar" from "table"
.. automethod:: format
.. automethod:: join .. automethod:: join
""" """
def __init__(self, string): def __init__(self, string):
@ -169,12 +173,73 @@ class SQL(Composable):
def __repr__(self): def __repr__(self):
return "sql.SQL(%r)" % (self._wrapped,) return "sql.SQL(%r)" % (self._wrapped,)
def __mod__(self, args):
return _compose(self._wrapped, args)
def as_string(self, conn_or_curs): def as_string(self, conn_or_curs):
return self._wrapped return self._wrapped
def format(self, *args, **kwargs):
"""
Merge `Composable` objects into a template.
:param `Composable` args: parameters to replace to numbered
(``{0}``, ``{1}``) or auto-numbered (``{}``) placeholders
:param `Composable` kwargs: parameters to replace to named (``{name}``)
placeholders
:return: the union of the `!SQL` string with placeholders replaced
:rtype: `Composed`
The method is similar to the Python `str.format()` method: the string
template supports auto-numbered (``{}``), numbered (``{0}``,
``{1}``...), and named placeholders (``{name}``), with positional
arguments replacing the numbered placeholders and keywords replacing
the named ones. However placeholder modifiers (``{{0!r}}``,
``{{0:<10}}``) are not supported. Only `!Composable` objects can be
passed to the template.
Example::
>>> print(sql.SQL("select * from {} where {} = %s")
... .format(sql.Identifier('people'), sql.Identifier('id'))
... .as_string(conn))
select * from "people" where "id" = %s
>>> print(sql.SQL("select * from {tbl} where {pkey} = %s")
... .format(tbl=sql.Identifier('people'), pkey=sql.Identifier('id'))
... .as_string(conn))
select * from "people" where "id" = %s
"""
rv = []
autonum = 0
for pre, name, spec, conv in _formatter.parse(self._wrapped):
if spec:
raise ValueError("no format specification supported by SQL")
if conv:
raise ValueError("no format conversion supported by SQL")
if pre:
rv.append(SQL(pre))
if name is None:
continue
if name.isdigit():
if autonum:
raise ValueError(
"cannot switch from automatic field numbering to manual")
rv.append(args[int(name)])
autonum = None
elif not name:
if autonum is None:
raise ValueError(
"cannot switch from manual field numbering to automatic")
rv.append(args[autonum])
autonum += 1
else:
rv.append(kwargs[name])
return Composed(rv)
def join(self, seq): def join(self, seq):
""" """
Join a sequence of `Composable` or a `Composed` and return a `!Composed`. Join a sequence of `Composable` or a `Composed` and return a `!Composed`.
@ -183,7 +248,8 @@ class SQL(Composable):
Example:: Example::
>>> snip - sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', 'baz'])) >>> snip = sql.SQL(', ').join(
... sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
>>> print(snip.as_string(conn)) >>> print(snip.as_string(conn))
"foo", "bar", "baz" "foo", "bar", "baz"
""" """
@ -331,123 +397,6 @@ class Placeholder(Composable):
return "%s" return "%s"
re_compose = re.compile("""
% # percent sign
(?:
([%s]) # either % or s
| \( ([^\)]+) \) s # or a (named)s placeholder (named captured)
)
""", re.VERBOSE)
def _compose(sql, args=None):
"""
Merge an SQL string with some variable parts.
The *sql* string can contain placeholders such as `%s` or `%(name)s`.
If the string must contain a literal ``%`` symbol use ``%%``. Note that,
unlike `~cursor.execute()`, the replacement ``%%`` |=>| ``%`` is *always*
performed, even if there is no argument.
.. |=>| unicode:: 0x21D2 .. double right arrow
*args* must be a sequence or mapping (according to the placeholder style)
of `Composable` instances.
The value returned is a `Composed` instance obtained replacing the
arguments to the query placeholders.
"""
if args is None:
args = ()
phs = list(re_compose.finditer(sql))
# check placeholders consistent
counts = {'%': 0, 's': 0, None: 0}
for ph in phs:
counts[ph.group(1)] += 1
npos = counts['s']
nnamed = counts[None]
if npos and nnamed:
raise ValueError(
"the sql string contains both named and positional placeholders")
elif npos:
if not isinstance(args, collections.Sequence):
raise TypeError(
"the sql string expects values in a sequence, got %s instead"
% type(args).__name__)
if len(args) != npos:
raise ValueError(
"the sql string expects %s values, got %s" % (npos, len(args)))
return _compose_seq(sql, phs, args)
elif nnamed:
if not isinstance(args, collections.Mapping):
raise TypeError(
"the sql string expects values in a mapping, got %s instead"
% type(args))
return _compose_map(sql, phs, args)
else:
if isinstance(args, collections.Sequence) and args:
raise ValueError(
"the sql string expects no value, got %s instead" % len(args))
# If args are a mapping, no placeholder is an acceptable case
# Convert %% into %
return _compose_seq(sql, phs, ())
def _compose_seq(sql, phs, args):
rv = []
j = 0
for i, ph in enumerate(phs):
if i:
rv.append(SQL(sql[phs[i - 1].end():ph.start()]))
else:
rv.append(SQL(sql[0:ph.start()]))
if ph.group(1) == 's':
rv.append(args[j])
j += 1
else:
rv.append(SQL('%'))
if phs:
rv.append(SQL(sql[phs[-1].end():]))
else:
rv.append(SQL(sql))
return Composed(rv)
def _compose_map(sql, phs, args):
rv = []
for i, ph in enumerate(phs):
if i:
rv.append(SQL(sql[phs[i - 1].end():ph.start()]))
else:
rv.append(SQL(sql[0:ph.start()]))
if ph.group(2):
rv.append(args[ph.group(2)])
else:
rv.append(SQL('%'))
if phs:
rv.append(SQL(sql[phs[-1].end():]))
else:
rv.append(sql)
return Composed(rv)
# Alias # Alias
PH = Placeholder PH = Placeholder

View File

@ -25,65 +25,89 @@
import datetime as dt import datetime as dt
from testutils import unittest, ConnectingTestCase from testutils import unittest, ConnectingTestCase
import psycopg2
from psycopg2 import sql from psycopg2 import sql
class ComposeTests(ConnectingTestCase): class SqlFormatTests(ConnectingTestCase):
def test_pos(self): def test_pos(self):
s = sql.SQL("select %s from %s") \ s = sql.SQL("select {} from {}").format(
% (sql.Identifier('field'), sql.Identifier('table')) sql.Identifier('field'), sql.Identifier('table'))
s1 = s.as_string(self.conn)
self.assert_(isinstance(s1, str))
self.assertEqual(s1, 'select "field" from "table"')
def test_pos_spec(self):
s = sql.SQL("select {0} from {1}").format(
sql.Identifier('field'), sql.Identifier('table'))
s1 = s.as_string(self.conn)
self.assert_(isinstance(s1, str))
self.assertEqual(s1, 'select "field" from "table"')
s = sql.SQL("select {1} from {0}").format(
sql.Identifier('table'), sql.Identifier('field'))
s1 = s.as_string(self.conn) s1 = s.as_string(self.conn)
self.assert_(isinstance(s1, str)) self.assert_(isinstance(s1, str))
self.assertEqual(s1, 'select "field" from "table"') self.assertEqual(s1, 'select "field" from "table"')
def test_dict(self): def test_dict(self):
s = sql.SQL("select %(f)s from %(t)s") \ s = sql.SQL("select {f} from {t}").format(
% {'f': sql.Identifier('field'), 't': sql.Identifier('table')} f=sql.Identifier('field'), t=sql.Identifier('table'))
s1 = s.as_string(self.conn) s1 = s.as_string(self.conn)
self.assert_(isinstance(s1, str)) self.assert_(isinstance(s1, str))
self.assertEqual(s1, 'select "field" from "table"') self.assertEqual(s1, 'select "field" from "table"')
def test_unicode(self): def test_unicode(self):
s = sql.SQL(u"select %s from %s") \ s = sql.SQL(u"select {} from {}").format(
% (sql.Identifier(u'field'), sql.Identifier('table')) sql.Identifier(u'field'), sql.Identifier('table'))
s1 = s.as_string(self.conn) s1 = s.as_string(self.conn)
self.assert_(isinstance(s1, unicode)) self.assert_(isinstance(s1, unicode))
self.assertEqual(s1, u'select "field" from "table"') self.assertEqual(s1, u'select "field" from "table"')
def test_compose_literal(self): def test_compose_literal(self):
s = sql.SQL("select %s;") % [sql.Literal(dt.date(2016, 12, 31))] s = sql.SQL("select {};").format(sql.Literal(dt.date(2016, 12, 31)))
s1 = s.as_string(self.conn) s1 = s.as_string(self.conn)
self.assertEqual(s1, "select '2016-12-31'::date;") self.assertEqual(s1, "select '2016-12-31'::date;")
def test_compose_empty(self): def test_compose_empty(self):
s = sql.SQL("select foo;") % () s = sql.SQL("select foo;").format()
s1 = s.as_string(self.conn) s1 = s.as_string(self.conn)
self.assertEqual(s1, "select foo;") self.assertEqual(s1, "select foo;")
def test_percent_escape(self): def test_percent_escape(self):
s = sql.SQL("42 %% %s") % [sql.Literal(7)] s = sql.SQL("42 % {}").format(sql.Literal(7))
s1 = s.as_string(self.conn) s1 = s.as_string(self.conn)
self.assertEqual(s1, "42 % 7") self.assertEqual(s1, "42 % 7")
s = sql.SQL("42 %% 7") % [] def test_braces_escape(self):
s1 = s.as_string(self.conn) s = sql.SQL("{{{}}}").format(sql.Literal(7))
self.assertEqual(s1, "42 % 7") self.assertEqual(s.as_string(self.conn), "{7}")
s = sql.SQL("{{1,{}}}").format(sql.Literal(7))
self.assertEqual(s.as_string(self.conn), "{1,7}")
def test_compose_badnargs(self): def test_compose_badnargs(self):
self.assertRaises(ValueError, sql.SQL("select foo;").__mod__, [10]) self.assertRaises(IndexError, sql.SQL("select {};").format)
self.assertRaises(ValueError, sql.SQL("select %s;").__mod__, []) self.assertRaises(ValueError, sql.SQL("select {} {1};").format, 10, 20)
self.assertRaises(ValueError, sql.SQL("select %s;").__mod__, [10, 20]) self.assertRaises(ValueError, sql.SQL("select {0} {};").format, 10, 20)
def test_compose_bad_args_type(self): def test_compose_bad_args_type(self):
self.assertRaises(TypeError, sql.SQL("select %s;").__mod__, {'a': 10}) self.assertRaises(IndexError, sql.SQL("select {};").format, a=10)
self.assertRaises(TypeError, sql.SQL("select %(x)s;").__mod__, [10]) self.assertRaises(KeyError, sql.SQL("select {x};").format, 10)
def test_must_be_composable(self):
self.assertRaises(TypeError, sql.SQL("select {};").format, 'foo')
self.assertRaises(TypeError, sql.SQL("select {};").format, 10)
def test_no_modifiers(self):
self.assertRaises(ValueError, sql.SQL("select {a!r};").format, a=10)
self.assertRaises(ValueError, sql.SQL("select {a:<};").format, a=10)
def test_must_be_adaptable(self): def test_must_be_adaptable(self):
class Foo(object): class Foo(object):
pass pass
self.assertRaises(TypeError, self.assertRaises(psycopg2.ProgrammingError,
sql.SQL("select %s;").__mod__, [Foo()]) sql.SQL("select {};").format(sql.Literal(Foo())).as_string, self.conn)
def test_execute(self): def test_execute(self):
cur = self.conn.cursor() cur = self.conn.cursor()
@ -93,11 +117,10 @@ class ComposeTests(ConnectingTestCase):
foo text, bar text, "ba'z" text) foo text, bar text, "ba'z" text)
""") """)
cur.execute( cur.execute(
sql.SQL("insert into %s (id, %s) values (%%s, %s)") % [ sql.SQL("insert into {} (id, {}) values (%s, {})").format(
sql.Identifier('test_compose'), sql.Identifier('test_compose'),
sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
(sql.PH() * 3).join(', '), (sql.PH() * 3).join(', ')),
],
(10, 'a', 'b', 'c')) (10, 'a', 'b', 'c'))
cur.execute("select * from test_compose") cur.execute("select * from test_compose")
@ -111,11 +134,10 @@ class ComposeTests(ConnectingTestCase):
foo text, bar text, "ba'z" text) foo text, bar text, "ba'z" text)
""") """)
cur.executemany( cur.executemany(
sql.SQL("insert into %s (id, %s) values (%%s, %s)") % [ sql.SQL("insert into {} (id, {}) values (%s, {})").format(
sql.Identifier('test_compose'), sql.Identifier('test_compose'),
sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
(sql.PH() * 3).join(', '), (sql.PH() * 3).join(', ')),
],
[(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
cur.execute("select * from test_compose") cur.execute("select * from test_compose")
@ -169,6 +191,13 @@ class LiteralTests(ConnectingTestCase):
sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn), sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn),
"'2017-01-01'::date") "'2017-01-01'::date")
def test_must_be_adaptable(self):
class Foo(object):
pass
self.assertRaises(psycopg2.ProgrammingError,
sql.Literal(Foo()).as_string, self.conn)
class SQLTests(ConnectingTestCase): class SQLTests(ConnectingTestCase):
def test_class(self): def test_class(self):