diff --git a/lib/sql.py b/lib/sql.py index a5043dcc..8be6542e 100644 --- a/lib/sql.py +++ b/lib/sql.py @@ -204,7 +204,7 @@ re_compose = re.compile(""" """, re.VERBOSE) -def compose(sql, args=None): +def compose(sql, args=()): phs = list(re_compose.finditer(sql)) # check placeholders consistent @@ -240,8 +240,8 @@ def compose(sql, args=None): return _compose_map(sql, phs, args) else: - if not isinstance(args, collections.Sequence) and args: - raise TypeError( + if isinstance(args, collections.Sequence) and args: + raise ValueError( "the sql string expects no value, got %s instead" % len(args)) # If args are a mapping, no placeholder is an acceptable case @@ -267,7 +267,7 @@ def _compose_seq(sql, phs, args): if phs: rv.append(SQL(sql[phs[-1].end():])) else: - rv.append(sql) + rv.append(SQL(sql)) return Composed(rv) diff --git a/tests/test_sql.py b/tests/test_sql.py index c8ec716b..f7c0801b 100755 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -55,6 +55,21 @@ class ComposeTests(ConnectingTestCase): s1 = s.as_string(self.conn) self.assertEqual(s1, "select '2016-12-31'::date;") + def test_compose_empty(self): + s = sql.compose("select foo;") + s1 = s.as_string(self.conn) + self.assertEqual(s1, "select foo;") + + def test_compose_badnargs(self): + self.assertRaises(ValueError, sql.compose, "select foo;", [10]) + self.assertRaises(ValueError, sql.compose, "select %s;") + self.assertRaises(ValueError, sql.compose, "select %s;", []) + self.assertRaises(ValueError, sql.compose, "select %s;", [10, 20]) + + def test_compose_bad_args_type(self): + self.assertRaises(TypeError, sql.compose, "select %s;", {'a': 10}) + self.assertRaises(TypeError, sql.compose, "select %(x)s;", [10]) + def test_must_be_adaptable(self): class Foo(object): pass