From 71a168797cc7e6398222490e6de02080de842e1d Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Tue, 3 Jan 2017 17:27:01 +0100 Subject: [PATCH] Several improvements to the sql objects Comparable, iterable, content accessible --- doc/src/sql.rst | 20 ++++++ lib/sql.py | 171 +++++++++++++++++++++++++--------------------- tests/test_sql.py | 89 ++++++++++++++++++++---- 3 files changed, 189 insertions(+), 91 deletions(-) diff --git a/doc/src/sql.rst b/doc/src/sql.rst index a2ed15e0..0aee4519 100644 --- a/doc/src/sql.rst +++ b/doc/src/sql.rst @@ -58,12 +58,32 @@ from the query parameters:: .. autoclass:: Composable + .. automethod:: as_string + + .. autoclass:: SQL + .. autoattribute:: string + + .. automethod:: format + + .. automethod:: join + + .. autoclass:: Identifier + .. autoattribute:: string + .. autoclass:: Literal + .. autoattribute:: wrapped + .. autoclass:: Placeholder + .. autoattribute:: name + .. autoclass:: Composed + + .. autoattribute:: seq + + .. automethod:: join diff --git a/lib/sql.py b/lib/sql.py index 23f66a61..e4d4b14b 100644 --- a/lib/sql.py +++ b/lib/sql.py @@ -37,23 +37,28 @@ class Composable(object): Abstract base class for objects that can be used to compose an SQL string. `!Composable` objects can be passed directly to `~cursor.execute()` and - `~cursor.executemany()`. + `~cursor.executemany()` in place of the query string. `!Composable` objects can be joined using the ``+`` operator: the result will be a `Composed` instance containing the objects joined. The operator ``*`` is also supported with an integer argument: the result is a `!Composed` instance containing the left argument repeated as many times as requested. - - .. automethod:: as_string """ - def as_string(self, conn_or_curs): + def __init__(self, wrapped): + self._wrapped = wrapped + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._wrapped) + + def as_string(self, context): """ Return the string value of the object. - The object is evaluated in the context of the *conn_or_curs* argument. + :param context: the context to evaluate the string into. + :type context: `connection` or `cursor` - The function is automatically invoked by `~cursor.execute()` and + The method is automatically invoked by `~cursor.execute()` and `~cursor.executemany()` if a `!Composable` is passed instead of the query string. """ @@ -70,14 +75,20 @@ class Composable(object): def __mul__(self, n): return Composed([self] * n) + def __eq__(self, other): + return type(self) is type(other) and self._wrapped == other._wrapped + + def __ne__(self, other): + return not self.__eq__(other) + class Composed(Composable): """ - A `Composable` object obtained concatenating a sequence of `Composable`. + A `Composable` object made of a sequence of `Composable`. - The object is usually created using `Composable` operators. However it is - possible to create a `!Composed` directly specifying a sequence of - `Composable` as arguments. + The object is usually created using `Composable` operators and methods. + However it is possible to create a `!Composed` directly specifying a + sequence of `Composable` as arguments. Example:: @@ -86,30 +97,38 @@ class Composed(Composable): >>> print(comp.as_string(conn)) insert into "table" - .. automethod:: join + `!Composed` objects are iterable (so they can be used in `SQL.join` for + instance). """ def __init__(self, seq): - self._seq = [] + wrapped = [] for i in seq: if not isinstance(i, Composable): raise TypeError( "Composed elements must be Composable, got %r instead" % i) - self._seq.append(i) + wrapped.append(i) - def __repr__(self): - return "sql.Composed(%r)" % (self._seq,) + super(Composed, self).__init__(wrapped) - def as_string(self, conn_or_curs): + @property + def seq(self): + """The list of the content of the `!Composed`.""" + return list(self._wrapped) + + def as_string(self, context): rv = [] - for i in self._seq: - rv.append(i.as_string(conn_or_curs)) + for i in self._wrapped: + rv.append(i.as_string(context)) return ''.join(rv) + def __iter__(self): + return iter(self._wrapped) + def __add__(self, other): if isinstance(other, Composed): - return Composed(self._seq + other._seq) + return Composed(self._wrapped + other._wrapped) if isinstance(other, Composable): - return Composed(self._seq + [other]) + return Composed(self._wrapped + [other]) else: return NotImplemented @@ -133,16 +152,7 @@ class Composed(Composable): raise TypeError( "Composed.join() argument must be a string or an SQL") - if len(self._seq) <= 1: - return self - - it = iter(self._seq) - rv = [it.next()] - for i in it: - rv.append(joiner) - rv.append(i) - - return Composed(rv) + return joiner.join(self) class SQL(Composable): @@ -153,6 +163,12 @@ class SQL(Composable): where to merge variable parts of a query (for instance field or table names). + The *string* doesn't undergo any form of escaping, so it is not suitable to + represent variable identifiers or values: you should only use it to pass + constant strings representing templates or snippets of SQL statements; use + other objects such as `Identifier` or `Literal` to represent variable + parts. + Example:: >>> query = sql.SQL("select {} from {}").format( @@ -160,20 +176,18 @@ class SQL(Composable): ... sql.Identifier('table')) >>> print(query.as_string(conn)) select "foo", "bar" from "table" - - .. automethod:: format - - .. automethod:: join """ def __init__(self, string): if not isinstance(string, basestring): raise TypeError("SQL values must be strings") - self._wrapped = string + super(SQL, self).__init__(string) - def __repr__(self): - return "sql.SQL(%r)" % (self._wrapped,) + @property + def string(self): + """The string wrapped by the `!SQL` object.""" + return self._wrapped - def as_string(self, conn_or_curs): + def as_string(self, context): return self._wrapped def format(self, *args, **kwargs): @@ -191,9 +205,9 @@ class SQL(Composable): template supports auto-numbered (``{}``), numbered (``{0}``, ``{1}``...), and named placeholders (``{name}``), with positional arguments replacing the numbered placeholders and keywords replacing - the named ones. However placeholder modifiers (``{{0!r}}``, - ``{{0:<10}}``) are not supported. Only `!Composable` objects can be - passed to the template. + the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``) + are not supported. Only `!Composable` objects can be passed to the + template. Example:: @@ -242,9 +256,14 @@ class SQL(Composable): def join(self, seq): """ - Join a sequence of `Composable` or a `Composed` and return a `!Composed`. + Join a sequence of `Composable`. - Use the object *string* to separate the *seq* elements. + :param seq: the elements to join. + :type seq: iterable of `!Composable` + + Use the `!SQL` object's *string* to separate the elements in *seq*. + Note that `Composed` objects are iterable too, so they can be used as + argument for this method. Example:: @@ -253,9 +272,6 @@ class SQL(Composable): >>> print(snip.as_string(conn)) "foo", "bar", "baz" """ - if isinstance(seq, Composed): - seq = seq._seq - rv = [] it = iter(seq) try: @@ -294,13 +310,15 @@ class Identifier(Composable): if not isinstance(string, basestring): raise TypeError("SQL identifiers must be strings") - self._wrapped = string + super(Identifier, self).__init__(string) - def __repr__(self): - return "sql.Identifier(%r)" % (self._wrapped,) + @property + def string(self): + """The string wrapped by the `Identifier`.""" + return self._wrapped - def as_string(self, conn_or_curs): - return ext.quote_ident(self._wrapped, conn_or_curs) + def as_string(self, context): + return ext.quote_ident(self._wrapped, context) class Literal(Composable): @@ -323,20 +341,19 @@ class Literal(Composable): 'foo', 'ba''r', 42 """ - def __init__(self, wrapped): - self._wrapped = wrapped + @property + def wrapped(self): + """The object wrapped by the `!Literal`.""" + return self._wrapped - def __repr__(self): - return "sql.Literal(%r)" % (self._wrapped,) - - def as_string(self, conn_or_curs): + def as_string(self, context): # is it a connection or cursor? - if isinstance(conn_or_curs, ext.connection): - conn = conn_or_curs - elif isinstance(conn_or_curs, ext.cursor): - conn = conn_or_curs.connection + if isinstance(context, ext.connection): + conn = context + elif isinstance(context, ext.cursor): + conn = context.connection else: - raise TypeError("conn_or_curs must be a connection or a cursor") + raise TypeError("context must be a connection or a cursor") a = ext.adapt(self._wrapped) if hasattr(a, 'prepare'): @@ -362,15 +379,15 @@ class Placeholder(Composable): >>> names = ['foo', 'bar', 'baz'] - >>> q1 = sql.SQL("insert into table (%s) values (%s)") % [ + >>> q1 = sql.SQL("insert into table ({}) values ({})").format( ... sql.SQL(', ').join(map(sql.Identifier, names)), - ... sql.SQL(', ').join(sql.Placeholder() * 3)] + ... sql.SQL(', ').join(sql.Placeholder() * len(names))) >>> print(q1.as_string(conn)) insert into table ("foo", "bar", "baz") values (%s, %s, %s) - >>> q2 = sql.SQL("insert into table (%s) values (%s)") % [ - ... sql.SQL(', ').join(map(sql.Identifier, names)), - ... sql.SQL(', ').join(map(sql.Placeholder, names))] + >>> q2 = sql.SQL("insert into table ({}) values ({})").format( + ... sql.SQL(', ').join(map(sql.Identifier, names)), + ... sql.SQL(', ').join(map(sql.Placeholder, names))) >>> print(q2.as_string(conn)) insert into table ("foo", "bar", "baz") values (%(foo)s, %(bar)s, %(baz)s) @@ -384,22 +401,24 @@ class Placeholder(Composable): elif name is not None: raise TypeError("expected string or None as name, got %r" % name) - self._name = name + super(Placeholder, self).__init__(name) + + @property + def name(self): + """The name of the `!Placeholder`.""" + return self._wrapped def __repr__(self): - return "sql.Placeholder(%r)" % ( - self._name if self._name is not None else '',) + return "Placeholder(%r)" % ( + self._wrapped if self._wrapped is not None else '',) - def as_string(self, conn_or_curs): - if self._name is not None: - return "%%(%s)s" % self._name + def as_string(self, context): + if self._wrapped is not None: + return "%%(%s)s" % self._wrapped else: return "%s" -# Alias -PH = Placeholder - # Literals NULL = SQL("NULL") DEFAULT = SQL("DEFAULT") diff --git a/tests/test_sql.py b/tests/test_sql.py index c2268fda..21c761e2 100755 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -120,7 +120,7 @@ class SqlFormatTests(ConnectingTestCase): sql.SQL("insert into {} (id, {}) values (%s, {})").format( sql.Identifier('test_compose'), sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), - (sql.PH() * 3).join(', ')), + (sql.Placeholder() * 3).join(', ')), (10, 'a', 'b', 'c')) cur.execute("select * from test_compose") @@ -137,7 +137,7 @@ class SqlFormatTests(ConnectingTestCase): sql.SQL("insert into {} (id, {}) values (%s, {})").format( sql.Identifier('test_compose'), sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), - (sql.PH() * 3).join(', ')), + (sql.Placeholder() * 3).join(', ')), [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) cur.execute("select * from test_compose") @@ -155,11 +155,20 @@ class IdentifierTests(ConnectingTestCase): self.assertRaises(TypeError, sql.Identifier, 10) self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31)) + def test_string(self): + self.assertEqual(sql.Identifier('foo').string, 'foo') + def test_repr(self): obj = sql.Identifier("fo'o") - self.assertEqual(repr(obj), 'sql.Identifier("fo\'o")') + self.assertEqual(repr(obj), 'Identifier("fo\'o")') self.assertEqual(repr(obj), str(obj)) + def test_eq(self): + self.assert_(sql.Identifier('foo') == sql.Identifier('foo')) + self.assert_(sql.Identifier('foo') != sql.Identifier('bar')) + self.assert_(sql.Identifier('foo') != 'foo') + self.assert_(sql.Identifier('foo') != sql.SQL('foo')) + def test_as_str(self): self.assertEqual(sql.Identifier('foo').as_string(self.conn), '"foo"') self.assertEqual(sql.Identifier("fo'o").as_string(self.conn), '"fo\'o"') @@ -180,9 +189,12 @@ class LiteralTests(ConnectingTestCase): self.assert_(isinstance( sql.Literal(dt.date(2016, 12, 31)), sql.Literal)) + def test_wrapped(self): + self.assertEqual(sql.Literal('foo').wrapped, 'foo') + def test_repr(self): - self.assertEqual(repr(sql.Literal("foo")), "sql.Literal('foo')") - self.assertEqual(str(sql.Literal("foo")), "sql.Literal('foo')") + self.assertEqual(repr(sql.Literal("foo")), "Literal('foo')") + self.assertEqual(str(sql.Literal("foo")), "Literal('foo')") self.assertEqual( sql.Literal("foo").as_string(self.conn).replace("E'", "'"), "'foo'") @@ -191,6 +203,12 @@ class LiteralTests(ConnectingTestCase): sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn), "'2017-01-01'::date") + def test_eq(self): + self.assert_(sql.Literal('foo') == sql.Literal('foo')) + self.assert_(sql.Literal('foo') != sql.Literal('bar')) + self.assert_(sql.Literal('foo') != 'foo') + self.assert_(sql.Literal('foo') != sql.SQL('foo')) + def test_must_be_adaptable(self): class Foo(object): pass @@ -209,11 +227,20 @@ class SQLTests(ConnectingTestCase): self.assertRaises(TypeError, sql.SQL, 10) self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31)) - def test_str(self): - self.assertEqual(repr(sql.SQL("foo")), "sql.SQL('foo')") - self.assertEqual(str(sql.SQL("foo")), "sql.SQL('foo')") + def test_string(self): + self.assertEqual(sql.SQL('foo').string, 'foo') + + def test_repr(self): + self.assertEqual(repr(sql.SQL("foo")), "SQL('foo')") + self.assertEqual(str(sql.SQL("foo")), "SQL('foo')") self.assertEqual(sql.SQL("foo").as_string(self.conn), "foo") + def test_eq(self): + self.assert_(sql.SQL('foo') == sql.SQL('foo')) + self.assert_(sql.SQL('foo') != sql.SQL('bar')) + self.assert_(sql.SQL('foo') != 'foo') + self.assert_(sql.SQL('foo') != sql.Literal('foo')) + def test_sum(self): obj = sql.SQL("foo") + sql.SQL("bar") self.assert_(isinstance(obj, sql.Composed)) @@ -241,6 +268,9 @@ class SQLTests(ConnectingTestCase): self.assert_(isinstance(obj, sql.Composed)) self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42') + obj = sql.SQL(", ").join([]) + self.assertEqual(obj, sql.Composed([])) + class ComposedTest(ConnectingTestCase): def test_class(self): @@ -249,9 +279,20 @@ class ComposedTest(ConnectingTestCase): def test_repr(self): obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) self.assertEqual(repr(obj), - """sql.Composed([sql.Literal('foo'), sql.Identifier("b'ar")])""") + """Composed([Literal('foo'), Identifier("b'ar")])""") self.assertEqual(str(obj), repr(obj)) + def test_seq(self): + l = [sql.SQL('foo'), sql.Literal('bar'), sql.Identifier('baz')] + self.assertEqual(sql.Composed(l).seq, l) + + def test_eq(self): + l = [sql.Literal("foo"), sql.Identifier("b'ar")] + l2 = [sql.Literal("foo"), sql.Literal("b'ar")] + self.assert_(sql.Composed(l) == sql.Composed(list(l))) + self.assert_(sql.Composed(l) != l) + self.assert_(sql.Composed(l) != sql.Composed(l2)) + def test_join(self): obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) obj = obj.join(", ") @@ -275,27 +316,45 @@ class ComposedTest(ConnectingTestCase): self.assert_(isinstance(obj, sql.Composed)) self.assertEqual(obj.as_string(self.conn), "foo 'bar'") + def test_iter(self): + obj = sql.Composed([sql.SQL("foo"), sql.SQL('bar')]) + it = iter(obj) + i = it.next() + self.assertEqual(i, sql.SQL('foo')) + i = it.next() + self.assertEqual(i, sql.SQL('bar')) + self.assertRaises(StopIteration, it.next) + class PlaceholderTest(ConnectingTestCase): def test_class(self): self.assert_(issubclass(sql.Placeholder, sql.Composable)) - def test_alias(self): - self.assert_(sql.Placeholder is sql.PH) + def test_name(self): + self.assertEqual(sql.Placeholder().name, None) + self.assertEqual(sql.Placeholder('foo').name, 'foo') def test_repr(self): - self.assert_(str(sql.Placeholder()), 'sql.Placeholder()') - self.assert_(repr(sql.Placeholder()), 'sql.Placeholder()') + self.assert_(str(sql.Placeholder()), 'Placeholder()') + self.assert_(repr(sql.Placeholder()), 'Placeholder()') self.assert_(sql.Placeholder().as_string(self.conn), '%s') def test_repr_name(self): - self.assert_(str(sql.Placeholder('foo')), "sql.Placeholder('foo')") - self.assert_(repr(sql.Placeholder('foo')), "sql.Placeholder('foo')") + self.assert_(str(sql.Placeholder('foo')), "Placeholder('foo')") + self.assert_(repr(sql.Placeholder('foo')), "Placeholder('foo')") self.assert_(sql.Placeholder('foo').as_string(self.conn), '%(foo)s') def test_bad_name(self): self.assertRaises(ValueError, sql.Placeholder, ')') + def test_eq(self): + self.assert_(sql.Placeholder('foo') == sql.Placeholder('foo')) + self.assert_(sql.Placeholder('foo') != sql.Placeholder('bar')) + self.assert_(sql.Placeholder('foo') != 'foo') + self.assert_(sql.Placeholder() == sql.Placeholder()) + self.assert_(sql.Placeholder('foo') != sql.Placeholder()) + self.assert_(sql.Placeholder('foo') != sql.Literal('foo')) + class ValuesTest(ConnectingTestCase): def test_null(self):