From dc1b4fff9001964c719e3f4471cc5a6fe6533e3a Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 2 Feb 2017 17:29:17 +0000 Subject: [PATCH] Avoid an useless encode/decode roundtrip in execute_values() Tests moved into a separate module. --- lib/extras.py | 48 ++++++- tests/__init__.py | 2 + tests/test_fast_executemany.py | 237 +++++++++++++++++++++++++++++++++ tests/test_types_extras.py | 178 ------------------------- 4 files changed, 283 insertions(+), 182 deletions(-) create mode 100755 tests/test_fast_executemany.py diff --git a/lib/extras.py b/lib/extras.py index 1aad3d1d..80034e6f 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -1232,10 +1232,50 @@ def execute_values(cur, sql, argslist, template=None, page_size=100): [(1, 20, 3), (4, 50, 6), (7, 8, 9)]) ''' + # we can't just use sql % vals because vals is bytes: if sql is bytes + # there will be some decoding error because of stupid codec used, and Py3 + # doesn't implement % on bytes. + if not isinstance(sql, bytes): + sql = sql.encode(_ext.encodings[cur.connection.encoding]) + pre, post = _split_sql(sql) + for page in _paginate(argslist, page_size=page_size): if template is None: template = '(%s)' % ','.join(['%s'] * len(page[0])) - values = b",".join(cur.mogrify(template, args) for args in page) - if isinstance(values, bytes): - values = values.decode(_ext.encodings[cur.connection.encoding]) - cur.execute(sql % (values,)) + parts = [pre] + for args in page: + parts.append(cur.mogrify(template, args)) + parts.append(b',') + parts[-1] = post + cur.execute(b''.join(parts)) + + +def _split_sql(sql): + """Split *sql* on a single ``%s`` placeholder. + + Return a (pre, post) pair around the ``%s``, with ``%%`` -> ``%`` replacement. + """ + curr = pre = [] + post = [] + tokens = _re.split(br'(%.)', sql) + for token in tokens: + if len(token) != 2 or token[:1] != b'%': + curr.append(token) + continue + + if token[1:] == b's': + if curr is pre: + curr = post + else: + raise ValueError( + "the query contains more than one '%s' placeholder") + elif token[1:] == b'%': + curr.append(b'%') + else: + raise ValueError("unsupported format character: '%s'" + % token[1:].decode('ascii', 'replace')) + + if curr is pre: + raise ValueError("the query doesn't contain any '%s' placeholder") + + return b''.join(pre), b''.join(post) diff --git a/tests/__init__.py b/tests/__init__.py index 1a240994..35837e82 100755 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -37,6 +37,7 @@ import test_cursor import test_dates import test_errcodes import test_extras_dictcursor +import test_fast_executemany import test_green import test_ipaddress import test_lobject @@ -74,6 +75,7 @@ def test_suite(): suite.addTest(test_dates.test_suite()) suite.addTest(test_errcodes.test_suite()) suite.addTest(test_extras_dictcursor.test_suite()) + suite.addTest(test_fast_executemany.test_suite()) suite.addTest(test_green.test_suite()) suite.addTest(test_ipaddress.test_suite()) suite.addTest(test_lobject.test_suite()) diff --git a/tests/test_fast_executemany.py b/tests/test_fast_executemany.py new file mode 100755 index 00000000..92222748 --- /dev/null +++ b/tests/test_fast_executemany.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python +# +# test_fast_executemany.py - tests for fast executemany implementations +# +# Copyright (C) 2017 Daniele Varrazzo +# +# psycopg2 is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# psycopg2 is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +import unittest +from datetime import date + +from testutils import ConnectingTestCase + +import psycopg2 +import psycopg2.extras +import psycopg2.extensions as ext + + +class TestPaginate(unittest.TestCase): + def test_paginate(self): + def pag(seq): + return psycopg2.extras._paginate(seq, 100) + + self.assertEqual(list(pag([])), []) + self.assertEqual(list(pag([1])), [[1]]) + self.assertEqual(list(pag(range(99))), [list(range(99))]) + self.assertEqual(list(pag(range(100))), [list(range(100))]) + self.assertEqual(list(pag(range(101))), [list(range(100)), [100]]) + self.assertEqual( + list(pag(range(200))), [list(range(100)), list(range(100, 200))]) + self.assertEqual( + list(pag(range(1000))), + [list(range(i * 100, (i + 1) * 100)) for i in range(10)]) + + +class FastExecuteTestMixin(object): + def setUp(self): + super(FastExecuteTestMixin, self).setUp() + cur = self.conn.cursor() + cur.execute("""create table testfast ( + id serial primary key, date date, val int, data text)""") + + +class TestExecuteBatch(FastExecuteTestMixin, ConnectingTestCase): + def test_empty(self): + cur = self.conn.cursor() + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, val) values (%s, %s)", + []) + cur.execute("select * from testfast order by id") + self.assertEqual(cur.fetchall(), []) + + def test_one(self): + cur = self.conn.cursor() + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, val) values (%s, %s)", + iter([(1, 10)])) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(1, 10)]) + + def test_tuples(self): + cur = self.conn.cursor() + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, date, val) values (%s, %s, %s)", + ((i, date(2017, 1, i + 1), i * 10) for i in range(10))) + cur.execute("select id, date, val from testfast order by id") + self.assertEqual(cur.fetchall(), + [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) + + def test_many(self): + cur = self.conn.cursor() + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, val) values (%s, %s)", + ((i, i * 10) for i in range(1000))) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) + + def test_pages(self): + cur = self.conn.cursor() + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, val) values (%s, %s)", + ((i, i * 10) for i in range(25)), + page_size=10) + + # last command was 5 statements + self.assertEqual(sum(c == u';' for c in cur.query.decode('ascii')), 4) + + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) + + def test_unicode(self): + cur = self.conn.cursor() + ext.register_type(ext.UNICODE, cur) + snowman = u"\u2603" + + # unicode in statement + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, + [(1, 'x')]) + cur.execute("select id, data from testfast where id = 1") + self.assertEqual(cur.fetchone(), (1, 'x')) + + # unicode in data + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, data) values (%s, %s)", + [(2, snowman)]) + cur.execute("select id, data from testfast where id = 2") + self.assertEqual(cur.fetchone(), (2, snowman)) + + # unicode in both + psycopg2.extras.execute_batch(cur, + "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, + [(3, snowman)]) + cur.execute("select id, data from testfast where id = 3") + self.assertEqual(cur.fetchone(), (3, snowman)) + + +class TestExecuteValuse(FastExecuteTestMixin, ConnectingTestCase): + def test_empty(self): + cur = self.conn.cursor() + psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s", + []) + cur.execute("select * from testfast order by id") + self.assertEqual(cur.fetchall(), []) + + def test_one(self): + cur = self.conn.cursor() + psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s", + iter([(1, 10)])) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(1, 10)]) + + def test_tuples(self): + cur = self.conn.cursor() + psycopg2.extras.execute_values(cur, + "insert into testfast (id, date, val) values %s", + ((i, date(2017, 1, i + 1), i * 10) for i in range(10))) + cur.execute("select id, date, val from testfast order by id") + self.assertEqual(cur.fetchall(), + [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) + + def test_dicts(self): + cur = self.conn.cursor() + psycopg2.extras.execute_values(cur, + "insert into testfast (id, date, val) values %s", + (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar") + for i in range(10)), + template='(%(id)s, %(date)s, %(val)s)') + cur.execute("select id, date, val from testfast order by id") + self.assertEqual(cur.fetchall(), + [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) + + def test_many(self): + cur = self.conn.cursor() + psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s", + ((i, i * 10) for i in range(1000))) + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) + + def test_pages(self): + cur = self.conn.cursor() + psycopg2.extras.execute_values(cur, + "insert into testfast (id, val) values %s", + ((i, i * 10) for i in range(25)), + page_size=10) + + # last statement was 5 tuples (one parens is for the fields list) + self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6) + + cur.execute("select id, val from testfast order by id") + self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) + + def test_unicode(self): + cur = self.conn.cursor() + ext.register_type(ext.UNICODE, cur) + snowman = u"\u2603" + + # unicode in statement + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %%s -- %s" % snowman, + [(1, 'x')]) + cur.execute("select id, data from testfast where id = 1") + self.assertEqual(cur.fetchone(), (1, 'x')) + + # unicode in data + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %s", + [(2, snowman)]) + cur.execute("select id, data from testfast where id = 2") + self.assertEqual(cur.fetchone(), (2, snowman)) + + # unicode in both + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %%s -- %s" % snowman, + [(3, snowman)]) + cur.execute("select id, data from testfast where id = 3") + self.assertEqual(cur.fetchone(), (3, snowman)) + + def test_invalid_sql(self): + cur = self.conn.cursor() + self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, + "insert", []) + self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, + "insert %s and %s", []) + self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, + "insert %f", []) + self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, + "insert %f %s", []) + + def test_percent_escape(self): + cur = self.conn.cursor() + psycopg2.extras.execute_values(cur, + "insert into testfast (id, data) values %s -- a%%b", + [(1, 'hi')]) + self.assert_(b'a%%b' not in cur.query) + self.assert_(b'a%b' in cur.query) + + cur.execute("select id, data from testfast") + self.assertEqual(cur.fetchall(), [(1, 'hi')]) + + +def test_suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py index a584c868..8e615616 100755 --- a/tests/test_types_extras.py +++ b/tests/test_types_extras.py @@ -1766,184 +1766,6 @@ class RangeCasterTestCase(ConnectingTestCase): decorate_all_tests(RangeCasterTestCase, skip_if_no_range) -class TestFastExecute(ConnectingTestCase): - def setUp(self): - super(TestFastExecute, self).setUp() - cur = self.conn.cursor() - cur.execute("""create table testfast ( - id serial primary key, date date, val int, data text)""") - - def test_paginate(self): - def pag(seq): - return psycopg2.extras._paginate(seq, 100) - - self.assertEqual(list(pag([])), []) - self.assertEqual(list(pag([1])), [[1]]) - self.assertEqual(list(pag(range(99))), [list(range(99))]) - self.assertEqual(list(pag(range(100))), [list(range(100))]) - self.assertEqual(list(pag(range(101))), [list(range(100)), [100]]) - self.assertEqual( - list(pag(range(200))), [list(range(100)), list(range(100, 200))]) - self.assertEqual( - list(pag(range(1000))), - [list(range(i * 100, (i + 1) * 100)) for i in range(10)]) - - def test_execute_batch_empty(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, val) values (%s, %s)", - []) - cur.execute("select * from testfast order by id") - self.assertEqual(cur.fetchall(), []) - - def test_execute_batch_one(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, val) values (%s, %s)", - iter([(1, 10)])) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(1, 10)]) - - def test_execute_batch_tuples(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, date, val) values (%s, %s, %s)", - ((i, date(2017, 1, i + 1), i * 10) for i in range(10))) - cur.execute("select id, date, val from testfast order by id") - self.assertEqual(cur.fetchall(), - [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) - - def test_execute_batch_many(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, val) values (%s, %s)", - ((i, i * 10) for i in range(1000))) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) - - def test_execute_batch_pages(self): - cur = self.conn.cursor() - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, val) values (%s, %s)", - ((i, i * 10) for i in range(25)), - page_size=10) - - # last command was 5 statements - self.assertEqual(sum(c == u';' for c in cur.query.decode('ascii')), 4) - - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) - - def test_execute_batch_unicode(self): - cur = self.conn.cursor() - ext.register_type(ext.UNICODE, cur) - snowman = u"\u2603" - - # unicode in statement - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, - [(1, 'x')]) - cur.execute("select id, data from testfast where id = 1") - self.assertEqual(cur.fetchone(), (1, 'x')) - - # unicode in data - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, data) values (%s, %s)", - [(2, snowman)]) - cur.execute("select id, data from testfast where id = 2") - self.assertEqual(cur.fetchone(), (2, snowman)) - - # unicode in both - psycopg2.extras.execute_batch(cur, - "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, - [(3, snowman)]) - cur.execute("select id, data from testfast where id = 3") - self.assertEqual(cur.fetchone(), (3, snowman)) - - def test_execute_values_empty(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s", - []) - cur.execute("select * from testfast order by id") - self.assertEqual(cur.fetchall(), []) - - def test_execute_values_one(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s", - iter([(1, 10)])) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(1, 10)]) - - def test_execute_values_tuples(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, date, val) values %s", - ((i, date(2017, 1, i + 1), i * 10) for i in range(10))) - cur.execute("select id, date, val from testfast order by id") - self.assertEqual(cur.fetchall(), - [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) - - def test_execute_values_dicts(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, date, val) values %s", - (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar") - for i in range(10)), - template='(%(id)s, %(date)s, %(val)s)') - cur.execute("select id, date, val from testfast order by id") - self.assertEqual(cur.fetchall(), - [(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) - - def test_execute_values_many(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s", - ((i, i * 10) for i in range(1000))) - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) - - def test_execute_values_pages(self): - cur = self.conn.cursor() - psycopg2.extras.execute_values(cur, - "insert into testfast (id, val) values %s", - ((i, i * 10) for i in range(25)), - page_size=10) - - # last statement was 5 tuples (one parens is for the fields list) - self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6) - - cur.execute("select id, val from testfast order by id") - self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) - - def test_execute_values_unicode(self): - cur = self.conn.cursor() - ext.register_type(ext.UNICODE, cur) - snowman = u"\u2603" - - # unicode in statement - psycopg2.extras.execute_values(cur, - "insert into testfast (id, data) values %%s -- %s" % snowman, - [(1, 'x')]) - cur.execute("select id, data from testfast where id = 1") - self.assertEqual(cur.fetchone(), (1, 'x')) - - # unicode in data - psycopg2.extras.execute_values(cur, - "insert into testfast (id, data) values %s", - [(2, snowman)]) - cur.execute("select id, data from testfast where id = 2") - self.assertEqual(cur.fetchone(), (2, snowman)) - - # unicode in both - psycopg2.extras.execute_values(cur, - "insert into testfast (id, data) values %%s -- %s" % snowman, - [(3, snowman)]) - cur.execute("select id, data from testfast where id = 3") - self.assertEqual(cur.fetchone(), (3, snowman)) - - def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)