diff --git a/lib/extras.py b/lib/extras.py index 2636655a..2d264025 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -1220,6 +1220,6 @@ def execute_values(cur, sql, argslist, template=None, page_size=100): if template is None: template = '(%s)' % ','.join(['%s'] * len(page[0])) values = b",".join(cur.mogrify(template, args) for args in page) - if isinstance(values, bytes) and _sys.version_info[0] > 2: + if isinstance(values, bytes): values = values.decode(_ext.encodings[cur.connection.encoding]) cur.execute(sql % (values,)) diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py index 952208c5..a584c868 100755 --- a/tests/test_types_extras.py +++ b/tests/test_types_extras.py @@ -1770,8 +1770,8 @@ class TestFastExecute(ConnectingTestCase): def setUp(self): super(TestFastExecute, self).setUp() cur = self.conn.cursor() - cur.execute( - "create table testfast (id serial primary key, date date, val int)") + cur.execute("""create table testfast ( + id serial primary key, date date, val int, data text)""") def test_paginate(self): def pag(seq): @@ -1834,6 +1834,32 @@ class TestFastExecute(ConnectingTestCase): cur.execute("select id, val from testfast order by id") self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) + def test_execute_batch_unicode(self): + cur = self.conn.cursor() + ext.register_type(ext.UNICODE, cur) + snowman = u"\u2603" + + # unicode in statement + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, + [(1, 'x')]) + cur.execute("select id, data from testfast where id = 1") + self.assertEqual(cur.fetchone(), (1, 'x')) + + # unicode in data + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, data) values (%s, %s)", + [(2, snowman)]) + cur.execute("select id, data from testfast where id = 2") + self.assertEqual(cur.fetchone(), (2, snowman)) + + # unicode in both + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, + [(3, snowman)]) + cur.execute("select id, data from testfast where id = 3") + self.assertEqual(cur.fetchone(), (3, snowman)) + def test_execute_values_empty(self): cur = self.conn.cursor() psycopg2.extras.execute_values(cur, @@ -1891,6 +1917,32 @@ class TestFastExecute(ConnectingTestCase): cur.execute("select id, val from testfast order by id") self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) + def test_execute_values_unicode(self): + cur = self.conn.cursor() + ext.register_type(ext.UNICODE, cur) + snowman = u"\u2603" + + # unicode in statement + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %%s -- %s" % snowman, + [(1, 'x')]) + cur.execute("select id, data from testfast where id = 1") + self.assertEqual(cur.fetchone(), (1, 'x')) + + # unicode in data + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %s", + [(2, snowman)]) + cur.execute("select id, data from testfast where id = 2") + self.assertEqual(cur.fetchone(), (2, snowman)) + + # unicode in both + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %%s -- %s" % snowman, + [(3, snowman)]) + cur.execute("select id, data from testfast where id = 3") + self.assertEqual(cur.fetchone(), (3, snowman)) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)