mirror of
https://github.com/psycopg/psycopg2.git
synced 2024-11-23 01:16:34 +03:00
Add tests for the asynchronous API
This commit is contained in:
parent
91ef0e09ed
commit
eb646f71fa
|
@ -27,6 +27,7 @@ import test_transaction
|
||||||
import types_basic
|
import types_basic
|
||||||
import types_extras
|
import types_extras
|
||||||
import test_lobject
|
import test_lobject
|
||||||
|
import test_async
|
||||||
|
|
||||||
def test_suite():
|
def test_suite():
|
||||||
suite = unittest.TestSuite()
|
suite = unittest.TestSuite()
|
||||||
|
@ -40,6 +41,7 @@ def test_suite():
|
||||||
suite.addTest(types_basic.test_suite())
|
suite.addTest(types_basic.test_suite())
|
||||||
suite.addTest(types_extras.test_suite())
|
suite.addTest(types_extras.test_suite())
|
||||||
suite.addTest(test_lobject.test_suite())
|
suite.addTest(test_lobject.test_suite())
|
||||||
|
suite.addTest(test_async.test_suite())
|
||||||
return suite
|
return suite
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
252
tests/test_async.py
Normal file
252
tests/test_async.py
Normal file
|
@ -0,0 +1,252 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import psycopg2
|
||||||
|
from psycopg2 import extensions
|
||||||
|
|
||||||
|
import select
|
||||||
|
import StringIO
|
||||||
|
|
||||||
|
import sys
|
||||||
|
if sys.version_info < (3,):
|
||||||
|
import tests
|
||||||
|
else:
|
||||||
|
import py3tests as tests
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncTests(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.sync_conn = psycopg2.connect(tests.dsn)
|
||||||
|
self.conn = psycopg2.connect(tests.dsn, async=True)
|
||||||
|
|
||||||
|
state = psycopg2.extensions.POLL_WRITE
|
||||||
|
while state != psycopg2.extensions.POLL_OK:
|
||||||
|
if state == psycopg2.extensions.POLL_WRITE:
|
||||||
|
select.select([], [self.conn.fileno()], [])
|
||||||
|
elif state == psycopg2.extensions.POLL_READ:
|
||||||
|
select.select([self.conn.fileno()], [], [])
|
||||||
|
state = self.conn.poll()
|
||||||
|
|
||||||
|
curs = self.conn.cursor()
|
||||||
|
curs.execute('''
|
||||||
|
CREATE TEMPORARY TABLE table1 (
|
||||||
|
id int PRIMARY KEY
|
||||||
|
)''')
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.sync_conn.close()
|
||||||
|
self.conn.close()
|
||||||
|
|
||||||
|
def wait_for_query(self, cur):
|
||||||
|
state = cur.poll()
|
||||||
|
while state != psycopg2.extensions.POLL_OK:
|
||||||
|
if state == psycopg2.extensions.POLL_READ:
|
||||||
|
select.select([cur.fileno()], [], [])
|
||||||
|
elif state == psycopg2.extensions.POLL_WRITE:
|
||||||
|
select.select([], [cur.fileno()], [])
|
||||||
|
state = cur.poll()
|
||||||
|
|
||||||
|
def test_wrong_execution_type(self):
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
sync_cur = self.sync_conn.cursor()
|
||||||
|
|
||||||
|
self.assertRaises(psycopg2.ProgrammingError, cur.execute,
|
||||||
|
"select 'a'", async=False)
|
||||||
|
self.assertRaises(psycopg2.ProgrammingError, sync_cur.execute,
|
||||||
|
"select 'a'", async=True)
|
||||||
|
|
||||||
|
# but this should work anyway
|
||||||
|
sync_cur.execute("select 'a'", async=False)
|
||||||
|
cur.execute("select 'a'", async=True)
|
||||||
|
|
||||||
|
def test_async_select(self):
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
self.assertFalse(self.conn.executing())
|
||||||
|
cur.execute("select 'a'")
|
||||||
|
self.assertTrue(self.conn.executing())
|
||||||
|
|
||||||
|
self.wait_for_query(cur)
|
||||||
|
|
||||||
|
self.assertFalse(self.conn.executing())
|
||||||
|
self.assertEquals(cur.fetchone()[0], "a")
|
||||||
|
|
||||||
|
def test_async_callproc(self):
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
try:
|
||||||
|
cur.callproc("pg_sleep", (0.1, ), True)
|
||||||
|
except psycopg2.ProgrammingError:
|
||||||
|
# PG <8.1 did not have pg_sleep
|
||||||
|
return
|
||||||
|
self.assertTrue(self.conn.executing())
|
||||||
|
|
||||||
|
self.wait_for_query(cur)
|
||||||
|
self.assertFalse(self.conn.executing())
|
||||||
|
self.assertEquals(cur.fetchall()[0][0], '')
|
||||||
|
|
||||||
|
def test_async_after_async(self):
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
cur2 = self.conn.cursor()
|
||||||
|
|
||||||
|
cur.execute("insert into table1 values (1)")
|
||||||
|
|
||||||
|
# an async execute after an async one blocks and waits for completion
|
||||||
|
cur.execute("select * from table1")
|
||||||
|
self.wait_for_query(cur)
|
||||||
|
|
||||||
|
self.assertEquals(cur.fetchall()[0][0], 1)
|
||||||
|
|
||||||
|
cur.execute("delete from table1")
|
||||||
|
self.wait_for_query(cur)
|
||||||
|
|
||||||
|
cur.execute("select * from table1")
|
||||||
|
self.wait_for_query(cur)
|
||||||
|
|
||||||
|
self.assertEquals(cur.fetchone(), None)
|
||||||
|
|
||||||
|
def test_fetch_after_async(self):
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
cur.execute("select 'a'")
|
||||||
|
|
||||||
|
# a fetch after an asynchronous query blocks and waits for completion
|
||||||
|
self.assertEquals(cur.fetchall()[0][0], "a")
|
||||||
|
|
||||||
|
def test_rollback_while_async(self):
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
|
||||||
|
cur.execute("select 'a'")
|
||||||
|
|
||||||
|
# a rollback blocks and should leave the connection in a workable state
|
||||||
|
self.conn.rollback()
|
||||||
|
self.assertFalse(self.conn.executing())
|
||||||
|
|
||||||
|
# try a sync cursor first
|
||||||
|
sync_cur = self.sync_conn.cursor()
|
||||||
|
sync_cur.execute("select 'b'")
|
||||||
|
self.assertEquals(sync_cur.fetchone()[0], "b")
|
||||||
|
|
||||||
|
# now try the async cursor
|
||||||
|
cur.execute("select 'c'")
|
||||||
|
self.wait_for_query(cur)
|
||||||
|
self.assertEquals(cur.fetchmany()[0][0], "c")
|
||||||
|
|
||||||
|
def test_commit_while_async(self):
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
|
||||||
|
cur.execute("insert into table1 values (1)")
|
||||||
|
|
||||||
|
# a commit blocks
|
||||||
|
self.conn.commit()
|
||||||
|
self.assertFalse(self.conn.executing())
|
||||||
|
|
||||||
|
cur.execute("select * from table1")
|
||||||
|
self.wait_for_query(cur)
|
||||||
|
self.assertEquals(cur.fetchall()[0][0], 1)
|
||||||
|
|
||||||
|
cur.execute("delete from table1")
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
cur.execute("select * from table1")
|
||||||
|
self.wait_for_query(cur)
|
||||||
|
self.assertEquals(cur.fetchone(), None)
|
||||||
|
|
||||||
|
def test_set_parameters_while_async(self):
|
||||||
|
prev_encoding = self.conn.encoding
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
|
||||||
|
cur.execute("select 'c'")
|
||||||
|
self.assertTrue(self.conn.executing())
|
||||||
|
|
||||||
|
# getting transaction status works
|
||||||
|
self.assertEquals(self.conn.get_transaction_status(),
|
||||||
|
extensions.TRANSACTION_STATUS_ACTIVE)
|
||||||
|
self.assertTrue(self.conn.executing())
|
||||||
|
|
||||||
|
# this issues a ROLLBACK internally
|
||||||
|
self.conn.set_client_encoding("LATIN1")
|
||||||
|
|
||||||
|
self.assertFalse(self.conn.executing())
|
||||||
|
self.assertEquals(self.conn.encoding, "LATIN1")
|
||||||
|
|
||||||
|
self.conn.set_client_encoding(prev_encoding)
|
||||||
|
|
||||||
|
def test_reset_while_async(self):
|
||||||
|
prev_encoding = self.conn.encoding
|
||||||
|
# pick something different than the current encoding
|
||||||
|
new_encoding = (prev_encoding == "LATIN1") and "UTF8" or "LATIN1"
|
||||||
|
|
||||||
|
self.conn.set_client_encoding(new_encoding)
|
||||||
|
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
cur.execute("select 'c'")
|
||||||
|
self.assertTrue(self.conn.executing())
|
||||||
|
|
||||||
|
self.conn.reset()
|
||||||
|
self.assertFalse(self.conn.executing())
|
||||||
|
self.assertEquals(self.conn.encoding, prev_encoding)
|
||||||
|
|
||||||
|
def test_async_iter(self):
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
|
||||||
|
cur.execute("insert into table1 values (1), (2), (3)")
|
||||||
|
self.wait_for_query(cur)
|
||||||
|
cur.execute("select id from table1 order by id")
|
||||||
|
|
||||||
|
# iteration just blocks
|
||||||
|
self.assertEquals(list(cur), [(1, ), (2, ), (3, )])
|
||||||
|
self.assertFalse(self.conn.executing())
|
||||||
|
|
||||||
|
def test_copy_while_async(self):
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
cur.execute("select 'a'")
|
||||||
|
|
||||||
|
# copy just blocks
|
||||||
|
cur.copy_from(StringIO.StringIO("1\n3\n5\n\\.\n"), "table1")
|
||||||
|
|
||||||
|
cur.execute("select * from table1 order by id")
|
||||||
|
self.assertEquals(cur.fetchall(), [(1, ), (3, ), (5, )])
|
||||||
|
|
||||||
|
def test_async_executemany(self):
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
self.assertRaises(
|
||||||
|
psycopg2.ProgrammingError,
|
||||||
|
cur.executemany, "insert into table1 values (%s)", [1, 2, 3])
|
||||||
|
|
||||||
|
def test_async_scroll(self):
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
cur.execute("insert into table1 values (1), (2), (3)")
|
||||||
|
self.wait_for_query(cur)
|
||||||
|
cur.execute("select id from table1 order by id")
|
||||||
|
|
||||||
|
# scroll blocks, but should work
|
||||||
|
cur.scroll(1)
|
||||||
|
self.assertFalse(self.conn.executing())
|
||||||
|
self.assertEquals(cur.fetchall(), [(2, ), (3, )])
|
||||||
|
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
cur.execute("select id from table1 order by id")
|
||||||
|
|
||||||
|
cur2 = self.conn.cursor()
|
||||||
|
self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1)
|
||||||
|
|
||||||
|
self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4)
|
||||||
|
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
cur.execute("select id from table1 order by id")
|
||||||
|
cur.scroll(2)
|
||||||
|
cur.scroll(-1)
|
||||||
|
self.assertEquals(cur.fetchall(), [(2, ), (3, )])
|
||||||
|
|
||||||
|
def test_async_dont_read_all(self):
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
cur.execute("select 'a'; select 'b'")
|
||||||
|
|
||||||
|
# fetch the result
|
||||||
|
self.wait_for_query(cur)
|
||||||
|
|
||||||
|
# it should be the result of the second query
|
||||||
|
self.assertEquals(cur.fetchone()[0][0], "b")
|
||||||
|
|
||||||
|
def test_suite():
|
||||||
|
return unittest.TestLoader().loadTestsFromName(__name__)
|
Loading…
Reference in New Issue
Block a user