From eb646f71fade043520d1ce2dc3b0ac1b3afdee79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Urba=C5=84ski?= Date: Fri, 26 Mar 2010 04:03:54 +0100 Subject: [PATCH] Add tests for the asynchronous API --- tests/__init__.py | 2 + tests/test_async.py | 252 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 254 insertions(+) create mode 100644 tests/test_async.py diff --git a/tests/__init__.py b/tests/__init__.py index a5d7118f..47aedc8b 100755 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -27,6 +27,7 @@ import test_transaction import types_basic import types_extras import test_lobject +import test_async def test_suite(): suite = unittest.TestSuite() @@ -40,6 +41,7 @@ def test_suite(): suite.addTest(types_basic.test_suite()) suite.addTest(types_extras.test_suite()) suite.addTest(test_lobject.test_suite()) + suite.addTest(test_async.test_suite()) return suite if __name__ == '__main__': diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 00000000..c2bbf4ce --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python +import unittest + +import psycopg2 +from psycopg2 import extensions + +import select +import StringIO + +import sys +if sys.version_info < (3,): + import tests +else: + import py3tests as tests + + +class AsyncTests(unittest.TestCase): + + def setUp(self): + self.sync_conn = psycopg2.connect(tests.dsn) + self.conn = psycopg2.connect(tests.dsn, async=True) + + state = psycopg2.extensions.POLL_WRITE + while state != psycopg2.extensions.POLL_OK: + if state == psycopg2.extensions.POLL_WRITE: + select.select([], [self.conn.fileno()], []) + elif state == psycopg2.extensions.POLL_READ: + select.select([self.conn.fileno()], [], []) + state = self.conn.poll() + + curs = self.conn.cursor() + curs.execute(''' + CREATE TEMPORARY TABLE table1 ( + id int PRIMARY KEY + )''') + self.conn.commit() + + def tearDown(self): + self.sync_conn.close() + self.conn.close() + + def wait_for_query(self, cur): + state = cur.poll() + while state != psycopg2.extensions.POLL_OK: + if state == psycopg2.extensions.POLL_READ: + select.select([cur.fileno()], [], []) + elif state == psycopg2.extensions.POLL_WRITE: + select.select([], [cur.fileno()], []) + state = cur.poll() + + def test_wrong_execution_type(self): + cur = self.conn.cursor() + sync_cur = self.sync_conn.cursor() + + self.assertRaises(psycopg2.ProgrammingError, cur.execute, + "select 'a'", async=False) + self.assertRaises(psycopg2.ProgrammingError, sync_cur.execute, + "select 'a'", async=True) + + # but this should work anyway + sync_cur.execute("select 'a'", async=False) + cur.execute("select 'a'", async=True) + + def test_async_select(self): + cur = self.conn.cursor() + self.assertFalse(self.conn.executing()) + cur.execute("select 'a'") + self.assertTrue(self.conn.executing()) + + self.wait_for_query(cur) + + self.assertFalse(self.conn.executing()) + self.assertEquals(cur.fetchone()[0], "a") + + def test_async_callproc(self): + cur = self.conn.cursor() + try: + cur.callproc("pg_sleep", (0.1, ), True) + except psycopg2.ProgrammingError: + # PG <8.1 did not have pg_sleep + return + self.assertTrue(self.conn.executing()) + + self.wait_for_query(cur) + self.assertFalse(self.conn.executing()) + self.assertEquals(cur.fetchall()[0][0], '') + + def test_async_after_async(self): + cur = self.conn.cursor() + cur2 = self.conn.cursor() + + cur.execute("insert into table1 values (1)") + + # an async execute after an async one blocks and waits for completion + cur.execute("select * from table1") + self.wait_for_query(cur) + + self.assertEquals(cur.fetchall()[0][0], 1) + + cur.execute("delete from table1") + self.wait_for_query(cur) + + cur.execute("select * from table1") + self.wait_for_query(cur) + + self.assertEquals(cur.fetchone(), None) + + def test_fetch_after_async(self): + cur = self.conn.cursor() + cur.execute("select 'a'") + + # a fetch after an asynchronous query blocks and waits for completion + self.assertEquals(cur.fetchall()[0][0], "a") + + def test_rollback_while_async(self): + cur = self.conn.cursor() + + cur.execute("select 'a'") + + # a rollback blocks and should leave the connection in a workable state + self.conn.rollback() + self.assertFalse(self.conn.executing()) + + # try a sync cursor first + sync_cur = self.sync_conn.cursor() + sync_cur.execute("select 'b'") + self.assertEquals(sync_cur.fetchone()[0], "b") + + # now try the async cursor + cur.execute("select 'c'") + self.wait_for_query(cur) + self.assertEquals(cur.fetchmany()[0][0], "c") + + def test_commit_while_async(self): + cur = self.conn.cursor() + + cur.execute("insert into table1 values (1)") + + # a commit blocks + self.conn.commit() + self.assertFalse(self.conn.executing()) + + cur.execute("select * from table1") + self.wait_for_query(cur) + self.assertEquals(cur.fetchall()[0][0], 1) + + cur.execute("delete from table1") + self.conn.commit() + + cur.execute("select * from table1") + self.wait_for_query(cur) + self.assertEquals(cur.fetchone(), None) + + def test_set_parameters_while_async(self): + prev_encoding = self.conn.encoding + cur = self.conn.cursor() + + cur.execute("select 'c'") + self.assertTrue(self.conn.executing()) + + # getting transaction status works + self.assertEquals(self.conn.get_transaction_status(), + extensions.TRANSACTION_STATUS_ACTIVE) + self.assertTrue(self.conn.executing()) + + # this issues a ROLLBACK internally + self.conn.set_client_encoding("LATIN1") + + self.assertFalse(self.conn.executing()) + self.assertEquals(self.conn.encoding, "LATIN1") + + self.conn.set_client_encoding(prev_encoding) + + def test_reset_while_async(self): + prev_encoding = self.conn.encoding + # pick something different than the current encoding + new_encoding = (prev_encoding == "LATIN1") and "UTF8" or "LATIN1" + + self.conn.set_client_encoding(new_encoding) + + cur = self.conn.cursor() + cur.execute("select 'c'") + self.assertTrue(self.conn.executing()) + + self.conn.reset() + self.assertFalse(self.conn.executing()) + self.assertEquals(self.conn.encoding, prev_encoding) + + def test_async_iter(self): + cur = self.conn.cursor() + + cur.execute("insert into table1 values (1), (2), (3)") + self.wait_for_query(cur) + cur.execute("select id from table1 order by id") + + # iteration just blocks + self.assertEquals(list(cur), [(1, ), (2, ), (3, )]) + self.assertFalse(self.conn.executing()) + + def test_copy_while_async(self): + cur = self.conn.cursor() + cur.execute("select 'a'") + + # copy just blocks + cur.copy_from(StringIO.StringIO("1\n3\n5\n\\.\n"), "table1") + + cur.execute("select * from table1 order by id") + self.assertEquals(cur.fetchall(), [(1, ), (3, ), (5, )]) + + def test_async_executemany(self): + cur = self.conn.cursor() + self.assertRaises( + psycopg2.ProgrammingError, + cur.executemany, "insert into table1 values (%s)", [1, 2, 3]) + + def test_async_scroll(self): + cur = self.conn.cursor() + cur.execute("insert into table1 values (1), (2), (3)") + self.wait_for_query(cur) + cur.execute("select id from table1 order by id") + + # scroll blocks, but should work + cur.scroll(1) + self.assertFalse(self.conn.executing()) + self.assertEquals(cur.fetchall(), [(2, ), (3, )]) + + cur = self.conn.cursor() + cur.execute("select id from table1 order by id") + + cur2 = self.conn.cursor() + self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1) + + self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4) + + cur = self.conn.cursor() + cur.execute("select id from table1 order by id") + cur.scroll(2) + cur.scroll(-1) + self.assertEquals(cur.fetchall(), [(2, ), (3, )]) + + def test_async_dont_read_all(self): + cur = self.conn.cursor() + cur.execute("select 'a'; select 'b'") + + # fetch the result + self.wait_for_query(cur) + + # it should be the result of the second query + self.assertEquals(cur.fetchone()[0][0], "b") + +def test_suite(): + return unittest.TestLoader().loadTestsFromName(__name__)