Avoid a useless encode/decode round trip in execute_values()

Tests moved into a separate module.
This commit is contained in:
Daniele Varrazzo 2017-02-02 17:29:17 +00:00
parent d2fdc5ca9f
commit dc1b4fff90
4 changed files with 283 additions and 182 deletions

View File

@ -1232,10 +1232,50 @@ def execute_values(cur, sql, argslist, template=None, page_size=100):
[(1, 20, 3), (4, 50, 6), (7, 8, 9)])
'''
# we can't just use sql % vals because vals is bytes: if sql is bytes
# there will be some decoding error because of stupid codec used, and Py3
# doesn't implement % on bytes.
if not isinstance(sql, bytes):
sql = sql.encode(_ext.encodings[cur.connection.encoding])
pre, post = _split_sql(sql)
for page in _paginate(argslist, page_size=page_size):
if template is None:
template = '(%s)' % ','.join(['%s'] * len(page[0]))
values = b",".join(cur.mogrify(template, args) for args in page)
if isinstance(values, bytes):
values = values.decode(_ext.encodings[cur.connection.encoding])
cur.execute(sql % (values,))
parts = [pre]
for args in page:
parts.append(cur.mogrify(template, args))
parts.append(b',')
parts[-1] = post
cur.execute(b''.join(parts))
def _split_sql(sql):
    """Cut the bytes query *sql* in two around its only ``%s`` placeholder.

    Every ``%%`` escape is collapsed into a single literal ``%``.  A ``%``
    that the scanner cannot pair with a following character (for instance
    one closing the query) is copied through untouched.

    Return a ``(before, after)`` pair of bytes strings: the text preceding
    the placeholder and the text following it.

    Raise ValueError if the query has no ``%s``, has more than one, or
    contains any other ``%x`` format sequence.
    """
    before = []
    after = []
    # List currently collecting fragments: switched once, when the
    # placeholder is met.
    target = before

    for piece in _re.split(br'(%.)', sql):
        is_directive = len(piece) == 2 and piece[:1] == b'%'
        if not is_directive:
            # Plain SQL text: keep it verbatim.
            target.append(piece)
            continue

        spec = piece[1:]
        if spec == b'%':
            # Escaped percent sign.
            target.append(b'%')
        elif spec != b's':
            raise ValueError("unsupported format character: '%s'"
                             % spec.decode('ascii', 'replace'))
        elif target is after:
            raise ValueError(
                "the query contains more than one '%s' placeholder")
        else:
            target = after

    if target is before:
        raise ValueError("the query doesn't contain any '%s' placeholder")

    return b''.join(before), b''.join(after)

View File

@ -37,6 +37,7 @@ import test_cursor
import test_dates
import test_errcodes
import test_extras_dictcursor
import test_fast_executemany
import test_green
import test_ipaddress
import test_lobject
@ -74,6 +75,7 @@ def test_suite():
suite.addTest(test_dates.test_suite())
suite.addTest(test_errcodes.test_suite())
suite.addTest(test_extras_dictcursor.test_suite())
suite.addTest(test_fast_executemany.test_suite())
suite.addTest(test_green.test_suite())
suite.addTest(test_ipaddress.test_suite())
suite.addTest(test_lobject.test_suite())

237
tests/test_fast_executemany.py Executable file
View File

@ -0,0 +1,237 @@
#!/usr/bin/env python
#
# test_fast_executemany.py - tests for fast executemany implementations
#
# Copyright (C) 2017 Daniele Varrazzo <daniele.varrazzo@gmail.com>
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import unittest
from datetime import date
from testutils import ConnectingTestCase
import psycopg2
import psycopg2.extras
import psycopg2.extensions as ext
class TestPaginate(unittest.TestCase):
    """Check how ``psycopg2.extras._paginate()`` chunks its input."""

    def test_paginate(self):
        # Every case uses a page size of 100 and materializes the pages.
        def pages(seq):
            return list(psycopg2.extras._paginate(seq, 100))

        # Empty and single-item inputs.
        self.assertEqual(pages([]), [])
        self.assertEqual(pages([1]), [[1]])

        # Just below, exactly at, and just above one page.
        self.assertEqual(pages(range(99)), [list(range(99))])
        self.assertEqual(pages(range(100)), [list(range(100))])
        self.assertEqual(pages(range(101)), [list(range(100)), [100]])

        # Exact multiples of the page size.
        two_pages = [list(range(100)), list(range(100, 200))]
        self.assertEqual(pages(range(200)), two_pages)

        ten_pages = [
            list(range(page * 100, (page + 1) * 100)) for page in range(10)]
        self.assertEqual(pages(range(1000)), ten_pages)
class FastExecuteTestMixin(object):
    """Mixin creating the ``testfast`` table before every test.

    ``setUp()`` chains to the next ``setUp()`` in the MRO and then uses
    ``self.conn``, so the class must be combined with a test case that
    provides both.
    """

    def setUp(self):
        super(FastExecuteTestMixin, self).setUp()
        self.conn.cursor().execute(
            "create table testfast (\n"
            "id serial primary key, date date, val int, data text)")
class TestExecuteBatch(FastExecuteTestMixin, ConnectingTestCase):
    """Exercise ``psycopg2.extras.execute_batch()`` on the ``testfast`` table."""

    def test_empty(self):
        # An empty arguments list must leave the table empty.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, val) values (%s, %s)"
        psycopg2.extras.execute_batch(cursor, insert, [])

        cursor.execute("select * from testfast order by id")
        self.assertEqual(cursor.fetchall(), [])

    def test_one(self):
        # The arguments are passed as a plain iterator, not a sequence.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, val) values (%s, %s)"
        psycopg2.extras.execute_batch(cursor, insert, iter([(1, 10)]))

        cursor.execute("select id, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), [(1, 10)])

    def test_tuples(self):
        # Rows are produced lazily by a generator and include a date.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, date, val) values (%s, %s, %s)"
        rows = ((n, date(2017, 1, n + 1), n * 10) for n in range(10))
        psycopg2.extras.execute_batch(cursor, insert, rows)

        expected = [(n, date(2017, 1, n + 1), n * 10) for n in range(10)]
        cursor.execute("select id, date, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), expected)

    def test_many(self):
        # 1000 rows: more than a single default-sized page.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, val) values (%s, %s)"
        rows = ((n, n * 10) for n in range(1000))
        psycopg2.extras.execute_batch(cursor, insert, rows)

        expected = [(n, n * 10) for n in range(1000)]
        cursor.execute("select id, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), expected)

    def test_pages(self):
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, val) values (%s, %s)"
        rows = ((n, n * 10) for n in range(25))
        psycopg2.extras.execute_batch(cursor, insert, rows, page_size=10)

        # 25 rows in pages of 10: the last command holds 5 statements,
        # hence 4 separating semicolons.
        last_command = cursor.query.decode('ascii')
        self.assertEqual(last_command.count(u';'), 4)

        expected = [(n, n * 10) for n in range(25)]
        cursor.execute("select id, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), expected)

    def test_unicode(self):
        cursor = self.conn.cursor()
        ext.register_type(ext.UNICODE, cursor)
        snowman = u"\u2603"

        plain_insert = "insert into testfast (id, data) values (%s, %s)"
        # Same statement with the snowman in a trailing SQL comment.
        snowman_insert = (
            "insert into testfast (id, data) values (%%s, %%s) -- %s"
            % snowman)

        # unicode in statement
        psycopg2.extras.execute_batch(cursor, snowman_insert, [(1, 'x')])
        cursor.execute("select id, data from testfast where id = 1")
        self.assertEqual(cursor.fetchone(), (1, 'x'))

        # unicode in data
        psycopg2.extras.execute_batch(cursor, plain_insert, [(2, snowman)])
        cursor.execute("select id, data from testfast where id = 2")
        self.assertEqual(cursor.fetchone(), (2, snowman))

        # unicode in both
        psycopg2.extras.execute_batch(cursor, snowman_insert, [(3, snowman)])
        cursor.execute("select id, data from testfast where id = 3")
        self.assertEqual(cursor.fetchone(), (3, snowman))
class TestExecuteValuse(FastExecuteTestMixin, ConnectingTestCase):
    """Exercise ``psycopg2.extras.execute_values()`` on the ``testfast`` table.

    NOTE: the class name is misspelled ("Valuse"); it is left unchanged so
    that existing test ids keep working.
    """

    def test_empty(self):
        # An empty arguments list must leave the table empty.
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(
            cur,
            "insert into testfast (id, val) values %s",
            [])
        cur.execute("select * from testfast order by id")
        self.assertEqual(cur.fetchall(), [])

    def test_one(self):
        # The arguments are passed as a plain iterator, not a sequence.
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(
            cur,
            "insert into testfast (id, val) values %s",
            iter([(1, 10)]))
        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(1, 10)])

    def test_tuples(self):
        # Rows are produced lazily by a generator and include a date.
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(
            cur,
            "insert into testfast (id, date, val) values %s",
            ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
        cur.execute("select id, date, val from testfast order by id")
        self.assertEqual(
            cur.fetchall(),
            [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])

    def test_dicts(self):
        # Rows are dicts rendered through a named-placeholder template; the
        # extra "foo" key is not referenced by the template.
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(
            cur,
            "insert into testfast (id, date, val) values %s",
            (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar")
                for i in range(10)),
            template='(%(id)s, %(date)s, %(val)s)')
        cur.execute("select id, date, val from testfast order by id")
        self.assertEqual(
            cur.fetchall(),
            [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])

    def test_many(self):
        # 1000 rows: more than a single default-sized page.
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(
            cur,
            "insert into testfast (id, val) values %s",
            ((i, i * 10) for i in range(1000)))
        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])

    def test_pages(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(
            cur,
            "insert into testfast (id, val) values %s",
            ((i, i * 10) for i in range(25)),
            page_size=10)
        # last statement was 5 tuples (one parens is for the fields list)
        self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6)
        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])

    def test_unicode(self):
        cur = self.conn.cursor()
        ext.register_type(ext.UNICODE, cur)
        snowman = u"\u2603"

        # unicode in statement
        psycopg2.extras.execute_values(
            cur,
            "insert into testfast (id, data) values %%s -- %s" % snowman,
            [(1, 'x')])
        cur.execute("select id, data from testfast where id = 1")
        self.assertEqual(cur.fetchone(), (1, 'x'))

        # unicode in data
        psycopg2.extras.execute_values(
            cur,
            "insert into testfast (id, data) values %s",
            [(2, snowman)])
        cur.execute("select id, data from testfast where id = 2")
        self.assertEqual(cur.fetchone(), (2, snowman))

        # unicode in both
        psycopg2.extras.execute_values(
            cur,
            "insert into testfast (id, data) values %%s -- %s" % snowman,
            [(3, snowman)])
        cur.execute("select id, data from testfast where id = 3")
        self.assertEqual(cur.fetchone(), (3, snowman))

    def test_invalid_sql(self):
        # Each query must be rejected with ValueError: no placeholder, two
        # placeholders, and an unsupported "%f" with and without a "%s".
        cur = self.conn.cursor()
        self.assertRaises(
            ValueError, psycopg2.extras.execute_values, cur,
            "insert", [])
        self.assertRaises(
            ValueError, psycopg2.extras.execute_values, cur,
            "insert %s and %s", [])
        self.assertRaises(
            ValueError, psycopg2.extras.execute_values, cur,
            "insert %f", [])
        self.assertRaises(
            ValueError, psycopg2.extras.execute_values, cur,
            "insert %f %s", [])

    def test_percent_escape(self):
        # A "%%" in the statement must reach the server as a single "%".
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(
            cur,
            "insert into testfast (id, data) values %s -- a%%b",
            [(1, 'hi')])
        # assertTrue rather than the assert_ alias, which was removed from
        # unittest in Python 3.12.
        self.assertTrue(b'a%%b' not in cur.query)
        self.assertTrue(b'a%b' in cur.query)
        cur.execute("select id, data from testfast")
        self.assertEqual(cur.fetchall(), [(1, 'hi')])
def test_suite():
    """Return a suite with every test defined in this module."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromName(__name__)


if __name__ == "__main__":
    unittest.main()

View File

@ -1766,184 +1766,6 @@ class RangeCasterTestCase(ConnectingTestCase):
decorate_all_tests(RangeCasterTestCase, skip_if_no_range)
class TestFastExecute(ConnectingTestCase):
    """Exercise ``_paginate()``, ``execute_batch()`` and ``execute_values()``
    from ``psycopg2.extras`` against a scratch ``testfast`` table.
    """

    def setUp(self):
        super(TestFastExecute, self).setUp()
        self.conn.cursor().execute(
            "create table testfast (\n"
            "id serial primary key, date date, val int, data text)")

    def test_paginate(self):
        # Every case uses a page size of 100 and materializes the pages.
        def pages(seq):
            return list(psycopg2.extras._paginate(seq, 100))

        self.assertEqual(pages([]), [])
        self.assertEqual(pages([1]), [[1]])
        self.assertEqual(pages(range(99)), [list(range(99))])
        self.assertEqual(pages(range(100)), [list(range(100))])
        self.assertEqual(pages(range(101)), [list(range(100)), [100]])

        two_pages = [list(range(100)), list(range(100, 200))]
        self.assertEqual(pages(range(200)), two_pages)

        ten_pages = [
            list(range(page * 100, (page + 1) * 100)) for page in range(10)]
        self.assertEqual(pages(range(1000)), ten_pages)

    def test_execute_batch_empty(self):
        # An empty arguments list must leave the table empty.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, val) values (%s, %s)"
        psycopg2.extras.execute_batch(cursor, insert, [])

        cursor.execute("select * from testfast order by id")
        self.assertEqual(cursor.fetchall(), [])

    def test_execute_batch_one(self):
        # The arguments are passed as a plain iterator, not a sequence.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, val) values (%s, %s)"
        psycopg2.extras.execute_batch(cursor, insert, iter([(1, 10)]))

        cursor.execute("select id, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), [(1, 10)])

    def test_execute_batch_tuples(self):
        # Rows are produced lazily by a generator and include a date.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, date, val) values (%s, %s, %s)"
        rows = ((n, date(2017, 1, n + 1), n * 10) for n in range(10))
        psycopg2.extras.execute_batch(cursor, insert, rows)

        expected = [(n, date(2017, 1, n + 1), n * 10) for n in range(10)]
        cursor.execute("select id, date, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), expected)

    def test_execute_batch_many(self):
        # 1000 rows: more than a single default-sized page.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, val) values (%s, %s)"
        rows = ((n, n * 10) for n in range(1000))
        psycopg2.extras.execute_batch(cursor, insert, rows)

        expected = [(n, n * 10) for n in range(1000)]
        cursor.execute("select id, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), expected)

    def test_execute_batch_pages(self):
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, val) values (%s, %s)"
        rows = ((n, n * 10) for n in range(25))
        psycopg2.extras.execute_batch(cursor, insert, rows, page_size=10)

        # 25 rows in pages of 10: the last command holds 5 statements,
        # hence 4 separating semicolons.
        last_command = cursor.query.decode('ascii')
        self.assertEqual(last_command.count(u';'), 4)

        expected = [(n, n * 10) for n in range(25)]
        cursor.execute("select id, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), expected)

    def test_execute_batch_unicode(self):
        cursor = self.conn.cursor()
        ext.register_type(ext.UNICODE, cursor)
        snowman = u"\u2603"

        plain_insert = "insert into testfast (id, data) values (%s, %s)"
        # Same statement with the snowman in a trailing SQL comment.
        snowman_insert = (
            "insert into testfast (id, data) values (%%s, %%s) -- %s"
            % snowman)

        # unicode in statement
        psycopg2.extras.execute_batch(cursor, snowman_insert, [(1, 'x')])
        cursor.execute("select id, data from testfast where id = 1")
        self.assertEqual(cursor.fetchone(), (1, 'x'))

        # unicode in data
        psycopg2.extras.execute_batch(cursor, plain_insert, [(2, snowman)])
        cursor.execute("select id, data from testfast where id = 2")
        self.assertEqual(cursor.fetchone(), (2, snowman))

        # unicode in both
        psycopg2.extras.execute_batch(cursor, snowman_insert, [(3, snowman)])
        cursor.execute("select id, data from testfast where id = 3")
        self.assertEqual(cursor.fetchone(), (3, snowman))

    def test_execute_values_empty(self):
        # An empty arguments list must leave the table empty.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, val) values %s"
        psycopg2.extras.execute_values(cursor, insert, [])

        cursor.execute("select * from testfast order by id")
        self.assertEqual(cursor.fetchall(), [])

    def test_execute_values_one(self):
        # The arguments are passed as a plain iterator, not a sequence.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, val) values %s"
        psycopg2.extras.execute_values(cursor, insert, iter([(1, 10)]))

        cursor.execute("select id, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), [(1, 10)])

    def test_execute_values_tuples(self):
        # Rows are produced lazily by a generator and include a date.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, date, val) values %s"
        rows = ((n, date(2017, 1, n + 1), n * 10) for n in range(10))
        psycopg2.extras.execute_values(cursor, insert, rows)

        expected = [(n, date(2017, 1, n + 1), n * 10) for n in range(10)]
        cursor.execute("select id, date, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), expected)

    def test_execute_values_dicts(self):
        # Rows are dicts rendered through a named-placeholder template; the
        # extra "foo" key is not referenced by the template.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, date, val) values %s"
        rows = (
            dict(id=n, date=date(2017, 1, n + 1), val=n * 10, foo="bar")
            for n in range(10))
        psycopg2.extras.execute_values(
            cursor, insert, rows, template='(%(id)s, %(date)s, %(val)s)')

        expected = [(n, date(2017, 1, n + 1), n * 10) for n in range(10)]
        cursor.execute("select id, date, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), expected)

    def test_execute_values_many(self):
        # 1000 rows: more than a single default-sized page.
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, val) values %s"
        rows = ((n, n * 10) for n in range(1000))
        psycopg2.extras.execute_values(cursor, insert, rows)

        expected = [(n, n * 10) for n in range(1000)]
        cursor.execute("select id, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), expected)

    def test_execute_values_pages(self):
        cursor = self.conn.cursor()
        insert = "insert into testfast (id, val) values %s"
        rows = ((n, n * 10) for n in range(25))
        psycopg2.extras.execute_values(cursor, insert, rows, page_size=10)

        # 25 rows in pages of 10: the last statement carries 5 tuples, and
        # one more open paren comes from the fields list.
        last_statement = cursor.query.decode('ascii')
        self.assertEqual(last_statement.count('('), 6)

        expected = [(n, n * 10) for n in range(25)]
        cursor.execute("select id, val from testfast order by id")
        self.assertEqual(cursor.fetchall(), expected)

    def test_execute_values_unicode(self):
        cursor = self.conn.cursor()
        ext.register_type(ext.UNICODE, cursor)
        snowman = u"\u2603"

        plain_insert = "insert into testfast (id, data) values %s"
        # Same statement with the snowman in a trailing SQL comment.
        snowman_insert = (
            "insert into testfast (id, data) values %%s -- %s" % snowman)

        # unicode in statement
        psycopg2.extras.execute_values(cursor, snowman_insert, [(1, 'x')])
        cursor.execute("select id, data from testfast where id = 1")
        self.assertEqual(cursor.fetchone(), (1, 'x'))

        # unicode in data
        psycopg2.extras.execute_values(cursor, plain_insert, [(2, snowman)])
        cursor.execute("select id, data from testfast where id = 2")
        self.assertEqual(cursor.fetchone(), (2, snowman))

        # unicode in both
        psycopg2.extras.execute_values(cursor, snowman_insert, [(3, snowman)])
        cursor.execute("select id, data from testfast where id = 3")
        self.assertEqual(cursor.fetchone(), (3, snowman))
def test_suite():
    """Return a suite with every test defined in this module."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromName(__name__)