mirror of
https://github.com/psycopg/psycopg2.git
synced 2024-11-27 03:13:43 +03:00
1d3a89a0bb
ag -l Copyright | xargs sed -i \ "s/\(.*copyright (C) [0-9]\+\)\(-[0-9]\+\)\?\(.*Psycopg Team.*\)/\1-$(date +%Y)\3/I"
270 lines
10 KiB
Python
Executable File
#!/usr/bin/env python
#
# test_fast_executemany.py - tests for fast executemany implementations
#
# Copyright (C) 2017-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
|
|
|
|
from datetime import date
|
|
|
|
from . import testutils
|
|
import unittest
|
|
|
|
import psycopg2
|
|
import psycopg2.extras
|
|
import psycopg2.extensions as ext
|
|
from psycopg2 import sql
|
|
|
|
|
|
class TestPaginate(unittest.TestCase):
    """Tests for the private ``psycopg2.extras._paginate`` helper."""

    def test_paginate(self):
        """Sequences are split into consecutive lists of at most 100 items."""
        def pages(seq):
            # Materialize the pages so they can be compared with plain lists.
            return list(psycopg2.extras._paginate(seq, 100))

        cases = [
            ([], []),
            ([1], [[1]]),
            (range(99), [list(range(99))]),
            (range(100), [list(range(100))]),
            (range(101), [list(range(100)), [100]]),
            (range(200), [list(range(100)), list(range(100, 200))]),
            (range(1000),
             [list(range(start, start + 100))
              for start in range(0, 1000, 100)]),
        ]
        for seq, expected in cases:
            self.assertEqual(pages(seq), expected)
|
|
|
|
|
|
class FastExecuteTestMixin:
    """Mixin that creates the ``testfast`` scratch table in ``setUp``.

    Meant to be listed before a test-case base that defines a cooperative
    ``setUp()`` and provides ``self.conn``.
    """

    # DDL for the table every fast-executemany test writes into.
    _TESTFAST_DDL = """create table testfast (
            id serial primary key, date date, val int, data text)"""

    def setUp(self):
        # Run the base fixture first; it presumably sets up ``self.conn``.
        super().setUp()
        self.conn.cursor().execute(self._TESTFAST_DDL)
|
|
|
|
|
|
class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
    """Tests for ``psycopg2.extras.execute_batch`` writing into ``testfast``."""

    def _fetch_all(self, cur, query):
        """Execute *query* on *cur* and return all the fetched rows."""
        cur.execute(query)
        return cur.fetchall()

    def _fetch_one(self, cur, query):
        """Execute *query* on *cur* and return a single fetched row."""
        cur.execute(query)
        return cur.fetchone()

    def test_empty(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_batch(
            cur, "insert into testfast (id, val) values (%s, %s)", [])
        rows = self._fetch_all(cur, "select * from testfast order by id")
        self.assertEqual(rows, [])

    def test_one(self):
        cur = self.conn.cursor()
        # The argument list is a one-shot iterator rather than a list.
        psycopg2.extras.execute_batch(
            cur, "insert into testfast (id, val) values (%s, %s)",
            iter([(1, 10)]))
        rows = self._fetch_all(
            cur, "select id, val from testfast order by id")
        self.assertEqual(rows, [(1, 10)])

    def test_tuples(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_batch(
            cur, "insert into testfast (id, date, val) values (%s, %s, %s)",
            ((n, date(2017, 1, n + 1), n * 10) for n in range(10)))
        rows = self._fetch_all(
            cur, "select id, date, val from testfast order by id")
        self.assertEqual(
            rows, [(n, date(2017, 1, n + 1), n * 10) for n in range(10)])

    def test_many(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_batch(
            cur, "insert into testfast (id, val) values (%s, %s)",
            ((n, n * 10) for n in range(1000)))
        rows = self._fetch_all(
            cur, "select id, val from testfast order by id")
        self.assertEqual(rows, [(n, n * 10) for n in range(1000)])

    def test_composed(self):
        cur = self.conn.cursor()
        # The statement is an sql.Composed object, not a plain string.
        statement = sql.SQL(
            "insert into {0} (id, val) values (%s, %s)").format(
                sql.Identifier('testfast'))
        psycopg2.extras.execute_batch(
            cur, statement, ((n, n * 10) for n in range(1000)))
        rows = self._fetch_all(
            cur, "select id, val from testfast order by id")
        self.assertEqual(rows, [(n, n * 10) for n in range(1000)])

    def test_pages(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_batch(
            cur, "insert into testfast (id, val) values (%s, %s)",
            ((n, n * 10) for n in range(25)),
            page_size=10)

        # 25 rows in pages of 10: the last page carries 5 statements, so
        # the last query sent contains 4 ';' separators.
        last_query = cur.query.decode('ascii')
        self.assertEqual(last_query.count(';'), 4)

        rows = self._fetch_all(
            cur, "select id, val from testfast order by id")
        self.assertEqual(rows, [(n, n * 10) for n in range(25)])

    @testutils.skip_before_postgres(8, 0)
    def test_unicode(self):
        cur = self.conn.cursor()
        ext.register_type(ext.UNICODE, cur)
        snowman = "\u2603"

        # ``%%s`` becomes the ``%s`` placeholder after the ``%`` formatting
        # that appends the non-ASCII comment.
        commented = (
            "insert into testfast (id, data) values (%%s, %%s) -- %s"
            % snowman)
        plain = "insert into testfast (id, data) values (%s, %s)"

        # Unicode in the statement, in the data, and in both.
        for row_id, statement, value in (
            (1, commented, 'x'),
            (2, plain, snowman),
            (3, commented, snowman),
        ):
            psycopg2.extras.execute_batch(cur, statement, [(row_id, value)])
            row = self._fetch_one(
                cur, "select id, data from testfast where id = %d" % row_id)
            self.assertEqual(row, (row_id, value))
|
|
|
|
|
|
@testutils.skip_before_postgres(8, 2)
class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
    """Tests for ``psycopg2.extras.execute_values`` writing into ``testfast``.

    The class is decorated with ``skip_before_postgres(8, 2)``, which
    presumably skips it on servers older than PostgreSQL 8.2.
    """

    def test_empty(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, val) values %s",
            [])
        cur.execute("select * from testfast order by id")
        self.assertEqual(cur.fetchall(), [])

    def test_one(self):
        cur = self.conn.cursor()
        # The argument list is a one-shot iterator rather than a list.
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, val) values %s",
            iter([(1, 10)]))
        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(1, 10)])

    def test_tuples(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, date, val) values %s",
            ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
        cur.execute("select id, date, val from testfast order by id")
        self.assertEqual(cur.fetchall(),
            [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])

    def test_dicts(self):
        cur = self.conn.cursor()
        # Mapping rows go through a named-placeholder template; the extra
        # "foo" key is not referenced by the template.
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, date, val) values %s",
            (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar")
                for i in range(10)),
            template='(%(id)s, %(date)s, %(val)s)')
        cur.execute("select id, date, val from testfast order by id")
        self.assertEqual(cur.fetchall(),
            [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])

    def test_many(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, val) values %s",
            ((i, i * 10) for i in range(1000)))
        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])

    def test_composed(self):
        cur = self.conn.cursor()
        # The statement is an sql.Composed object, not a plain string.
        psycopg2.extras.execute_values(cur,
            sql.SQL("insert into {0} (id, val) values %s")
                .format(sql.Identifier('testfast')),
            ((i, i * 10) for i in range(1000)))
        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])

    def test_pages(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, val) values %s",
            ((i, i * 10) for i in range(25)),
            page_size=10)

        # last statement was 5 tuples (one parens is for the fields list)
        self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6)

        cur.execute("select id, val from testfast order by id")
        self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])

    def test_unicode(self):
        cur = self.conn.cursor()
        ext.register_type(ext.UNICODE, cur)
        snowman = "\u2603"

        # unicode in statement
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, data) values %%s -- %s" % snowman,
            [(1, 'x')])
        cur.execute("select id, data from testfast where id = 1")
        self.assertEqual(cur.fetchone(), (1, 'x'))

        # unicode in data
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, data) values %s",
            [(2, snowman)])
        cur.execute("select id, data from testfast where id = 2")
        self.assertEqual(cur.fetchone(), (2, snowman))

        # unicode in both
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, data) values %%s -- %s" % snowman,
            [(3, snowman)])
        cur.execute("select id, data from testfast where id = 3")
        self.assertEqual(cur.fetchone(), (3, snowman))

    def test_returning(self):
        cur = self.conn.cursor()
        result = psycopg2.extras.execute_values(cur,
            "insert into testfast (id, val) values %s returning id",
            ((i, i * 10) for i in range(25)),
            page_size=10, fetch=True)
        # With fetch=True the result holds the rows returned by every page,
        # not only the last one.
        self.assertEqual([r[0] for r in result], list(range(25)))

    def test_invalid_sql(self):
        cur = self.conn.cursor()
        # Statements with no "%s", with several "%s", or with a placeholder
        # other than "%s" must all be rejected with ValueError.
        for query in ("insert", "insert %s and %s", "insert %f",
                      "insert %f %s"):
            with self.assertRaises(ValueError):
                psycopg2.extras.execute_values(cur, query, [])

    def test_percent_escape(self):
        cur = self.conn.cursor()
        psycopg2.extras.execute_values(cur,
            "insert into testfast (id, data) values %s -- a%%b",
            [(1, 'hi')])
        # A doubled "%%" in the statement must reach the server as one "%".
        # TestCase.assert_ was removed in Python 3.12: use assertIn/NotIn.
        self.assertNotIn(b'a%%b', cur.query)
        self.assertIn(b'a%b', cur.query)

        cur.execute("select id, data from testfast")
        self.assertEqual(cur.fetchall(), [(1, 'hi')])
|
|
|
|
|
|
def test_suite():
    """Build a suite holding every test case defined in this module."""
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromName(__name__)
    return suite
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: let unittest discover and run this module's tests.
    unittest.main()
|