Fixed sql.compose with no args and added tests

This commit is contained in:
Daniele Varrazzo 2017-01-01 06:26:54 +01:00
parent 8c020ca47a
commit ad2643266f
2 changed files with 19 additions and 4 deletions

View File

@ -204,7 +204,7 @@ re_compose = re.compile("""
""", re.VERBOSE) """, re.VERBOSE)
def compose(sql, args=None): def compose(sql, args=()):
phs = list(re_compose.finditer(sql)) phs = list(re_compose.finditer(sql))
# check placeholders consistent # check placeholders consistent
@ -240,8 +240,8 @@ def compose(sql, args=None):
return _compose_map(sql, phs, args) return _compose_map(sql, phs, args)
else: else:
if not isinstance(args, collections.Sequence) and args: if isinstance(args, collections.Sequence) and args:
raise TypeError( raise ValueError(
"the sql string expects no value, got %s instead" % len(args)) "the sql string expects no value, got %s instead" % len(args))
# If args are a mapping, no placeholder is an acceptable case # If args are a mapping, no placeholder is an acceptable case
@ -267,7 +267,7 @@ def _compose_seq(sql, phs, args):
if phs: if phs:
rv.append(SQL(sql[phs[-1].end():])) rv.append(SQL(sql[phs[-1].end():]))
else: else:
rv.append(sql) rv.append(SQL(sql))
return Composed(rv) return Composed(rv)

View File

@ -55,6 +55,21 @@ class ComposeTests(ConnectingTestCase):
s1 = s.as_string(self.conn) s1 = s.as_string(self.conn)
self.assertEqual(s1, "select '2016-12-31'::date;") self.assertEqual(s1, "select '2016-12-31'::date;")
def test_compose_empty(self):
s = sql.compose("select foo;")
s1 = s.as_string(self.conn)
self.assertEqual(s1, "select foo;")
def test_compose_badnargs(self):
self.assertRaises(ValueError, sql.compose, "select foo;", [10])
self.assertRaises(ValueError, sql.compose, "select %s;")
self.assertRaises(ValueError, sql.compose, "select %s;", [])
self.assertRaises(ValueError, sql.compose, "select %s;", [10, 20])
def test_compose_bad_args_type(self):
self.assertRaises(TypeError, sql.compose, "select %s;", {'a': 10})
self.assertRaises(TypeError, sql.compose, "select %(x)s;", [10])
def test_must_be_adaptable(self): def test_must_be_adaptable(self):
class Foo(object): class Foo(object):
pass pass