#!/usr/bin/env python
#
# test_fast_executemany.py - tests for fast executemany implementations
#
# Copyright (C) 2017 Daniele Varrazzo
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
# License for more details.

from datetime import date

import testutils
from testutils import unittest

import psycopg2
import psycopg2.extras
import psycopg2.extensions as ext


class TestPaginate(unittest.TestCase):
    """Tests for psycopg2.extras._paginate().

    _paginate(seq, page_size) is expected to yield lists of at most
    page_size items which, concatenated, reproduce the input sequence.
    """

    def test_paginate(self):
        def pag(seq):
            return psycopg2.extras._paginate(seq, 100)

        self.assertEqual(list(pag([])), [])
        self.assertEqual(list(pag([1])), [[1]])
        self.assertEqual(list(pag(range(99))), [list(range(99))])
        self.assertEqual(list(pag(range(100))), [list(range(100))])
        self.assertEqual(list(pag(range(101))), [list(range(100)), [100]])
        self.assertEqual(
            list(pag(range(200))), [list(range(100)), list(range(100, 200))])
        self.assertEqual(
            list(pag(range(1000))),
            [list(range(i * 100, (i + 1) * 100)) for i in range(10)])


class FastExecuteTestMixin(object):
    """Mixin creating the ``testfast`` table the tests write into.

    Must be mixed into a TestCase that provides ``self.conn`` (here
    testutils.ConnectingTestCase).
    """

    def setUp(self):
        super(FastExecuteTestMixin, self).setUp()
        cur = self.conn.cursor()
        cur.execute("""create table testfast (
            id serial primary key, date date, val int, data text)""")


class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
    """Tests for psycopg2.extras.execute_batch()."""

    def test_empty(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_batch(cur,
            "insert into testfast (id, val) values (%s, %s)",
            [])
        cur.execute("select * from testfast order by id")
        self.assertEqual(cur.fetchall(), [])

    def test_one(self):
        # argslist may be any iterable, not only a list
        cur = self.conn.cursor()
        psycopg2.extras.execute_batch(cur,
            "insert into testfast (id, val) values (%s, %s)",
            iter([(1, 10)]))
        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(1, 10)])

    def test_tuples(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_batch(cur,
            "insert into testfast (id, date, val) values (%s, %s, %s)",
            ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
        cur.execute("select id, date, val from testfast order by id")
        self.assertEqual(cur.fetchall(),
            [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])

    def test_many(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_batch(cur,
            "insert into testfast (id, val) values (%s, %s)",
            ((i, i * 10) for i in range(1000)))
        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])

    def test_composed(self):
        # the statement may be a psycopg2.sql Composable, not only a string
        from psycopg2 import sql
        cur = self.conn.cursor()
        psycopg2.extras.execute_batch(cur,
            sql.SQL("insert into {0} (id, val) values (%s, %s)")
                .format(sql.Identifier('testfast')),
            ((i, i * 10) for i in range(1000)))
        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])

    def test_pages(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_batch(cur,
            "insert into testfast (id, val) values (%s, %s)",
            ((i, i * 10) for i in range(25)),
            page_size=10)

        # 25 rows in pages of 10: the last command sent holds the remaining
        # 5 statements, joined by 4 semicolons.
        self.assertEqual(cur.query.count(b';'), 4)

        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])

    @testutils.skip_before_postgres(8, 0)
    def test_unicode(self):
        cur = self.conn.cursor()
        ext.register_type(ext.UNICODE, cur)
        snowman = u"\u2603"

        # unicode in statement
        psycopg2.extras.execute_batch(cur,
            "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
            [(1, 'x')])
        cur.execute("select id, data from testfast where id = 1")
        self.assertEqual(cur.fetchone(), (1, 'x'))

        # unicode in data
        psycopg2.extras.execute_batch(cur,
            "insert into testfast (id, data) values (%s, %s)",
            [(2, snowman)])
        cur.execute("select id, data from testfast where id = 2")
        self.assertEqual(cur.fetchone(), (2, snowman))

        # unicode in both
        psycopg2.extras.execute_batch(cur,
            "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
            [(3, snowman)])
        cur.execute("select id, data from testfast where id = 3")
        self.assertEqual(cur.fetchone(), (3, snowman))


class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
    """Tests for psycopg2.extras.execute_values().

    The statement carries a single %s placeholder that is replaced by the
    whole list of value tuples.
    """

    def test_empty(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, val) values %s",
            [])
        cur.execute("select * from testfast order by id")
        self.assertEqual(cur.fetchall(), [])

    def test_one(self):
        # argslist may be any iterable, not only a list
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, val) values %s",
            iter([(1, 10)]))
        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(1, 10)])

    def test_tuples(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, date, val) values %s",
            ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
        cur.execute("select id, date, val from testfast order by id")
        self.assertEqual(cur.fetchall(),
            [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])

    def test_dicts(self):
        # with a named-placeholder template, extra dict keys ("foo") are
        # ignored
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, date, val) values %s",
            (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar")
                for i in range(10)),
            template='(%(id)s, %(date)s, %(val)s)')
        cur.execute("select id, date, val from testfast order by id")
        self.assertEqual(cur.fetchall(),
            [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])

    def test_many(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, val) values %s",
            ((i, i * 10) for i in range(1000)))
        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])

    def test_composed(self):
        # the statement may be a psycopg2.sql Composable, not only a string
        from psycopg2 import sql
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            sql.SQL("insert into {0} (id, val) values %s")
                .format(sql.Identifier('testfast')),
            ((i, i * 10) for i in range(1000)))
        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])

    def test_pages(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, val) values %s",
            ((i, i * 10) for i in range(25)),
            page_size=10)

        # 25 rows in pages of 10: the last statement sent holds 5 tuples,
        # plus one more '(' for the "(id, val)" field list.
        self.assertEqual(cur.query.count(b'('), 6)

        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])

    def test_unicode(self):
        cur = self.conn.cursor()
        ext.register_type(ext.UNICODE, cur)
        snowman = u"\u2603"

        # unicode in statement
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, data) values %%s -- %s" % snowman,
            [(1, 'x')])
        cur.execute("select id, data from testfast where id = 1")
        self.assertEqual(cur.fetchone(), (1, 'x'))

        # unicode in data
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, data) values %s",
            [(2, snowman)])
        cur.execute("select id, data from testfast where id = 2")
        self.assertEqual(cur.fetchone(), (2, snowman))

        # unicode in both
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, data) values %%s -- %s" % snowman,
            [(3, snowman)])
        cur.execute("select id, data from testfast where id = 3")
        self.assertEqual(cur.fetchone(), (3, snowman))

    def test_invalid_sql(self):
        # the statement must contain exactly one %s and no other placeholder
        cur = self.conn.cursor()
        self.assertRaises(ValueError, psycopg2.extras.execute_values,
            cur, "insert", [])
        self.assertRaises(ValueError, psycopg2.extras.execute_values,
            cur, "insert %s and %s", [])
        self.assertRaises(ValueError, psycopg2.extras.execute_values,
            cur, "insert %f", [])
        self.assertRaises(ValueError, psycopg2.extras.execute_values,
            cur, "insert %f %s", [])

    def test_percent_escape(self):
        # a literal %% in the statement must reach the server as a single %
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, data) values %s -- a%%b",
            [(1, 'hi')])
        # assertTrue instead of the deprecated assert_ alias, which was
        # removed from unittest.TestCase in Python 3.12.
        self.assertTrue(b'a%%b' not in cur.query)
        self.assertTrue(b'a%b' in cur.query)

        cur.execute("select id, data from testfast")
        self.assertEqual(cur.fetchall(), [(1, 'hi')])


# Every execute_values() test is skipped on servers older than PostgreSQL 8.2.
testutils.decorate_all_tests(TestExecuteValues,
    testutils.skip_before_postgres(8, 2))


def test_suite():
    return unittest.TestLoader().loadTestsFromName(__name__)


if __name__ == "__main__":
    unittest.main()