Make the first poll() of an asynchronous connection return POLL_WRITE.

This hides from the user the libpq's implementation detail of
requiring the first select() to wait for the connection socket to
become writable and makes it possible to have a uniform select loop
for both cursors and connections, in which you always start by polling
the object and then acting according to the result from poll().

Idea and implementation by Daniele Varrazzo.
This commit is contained in:
Jan Urbański 2010-04-10 18:54:49 +02:00 committed by Federico Di Gregorio
parent 4afc1baf35
commit 6108e4dc92
3 changed files with 41 additions and 37 deletions

View File

@ -37,6 +37,7 @@ extern "C" {
#endif
/* connection status */
#define CONN_STATUS_SETUP 0
#define CONN_STATUS_READY 1
#define CONN_STATUS_BEGIN 2
#define CONN_STATUS_SYNC 3

View File

@ -425,6 +425,12 @@ psyco_conn_poll(connectionObject *self)
switch (self->status) {
case CONN_STATUS_SETUP:
/* according to libpq documentation the user should start by waiting
for the socket to become writable */
self->status = CONN_STATUS_ASYNC;
return PyInt_FromLong(PSYCO_POLL_WRITE);
case CONN_STATUS_SEND_DATESTYLE:
case CONN_STATUS_SENT_DATESTYLE:
case CONN_STATUS_SEND_CLIENT_ENCODING:
@ -686,7 +692,7 @@ connection_setup(connectionObject *self, const char *dsn, long int async)
self->notifies = PyList_New(0);
self->closed = 0;
self->async = async;
self->status = async ? CONN_STATUS_ASYNC : CONN_STATUS_READY;
self->status = async ? CONN_STATUS_SETUP : CONN_STATUS_READY;
self->critical = NULL;
self->async_cursor = NULL;
self->pgconn = NULL;

View File

@ -35,33 +35,30 @@ class AsyncTests(unittest.TestCase):
self.sync_conn = psycopg2.connect(tests.dsn)
self.conn = psycopg2.connect(tests.dsn, async=True)
state = psycopg2.extensions.POLL_WRITE
while state != psycopg2.extensions.POLL_OK:
if state == psycopg2.extensions.POLL_WRITE:
select.select([], [self.conn.fileno()], [])
elif state == psycopg2.extensions.POLL_READ:
select.select([self.conn.fileno()], [], [])
state = self.conn.poll()
self.wait(self.conn)
curs = self.conn.cursor()
curs.execute('''
CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY
)''')
self.wait_for_query(curs)
self.wait(curs)
def tearDown(self):
self.sync_conn.close()
self.conn.close()
def wait_for_query(self, cur):
state = cur.poll()
while state != psycopg2.extensions.POLL_OK:
if state == psycopg2.extensions.POLL_READ:
select.select([cur.fileno()], [], [])
def wait(self, pollable):
while True:
state = pollable.poll()
if state == psycopg2.extensions.POLL_OK:
break
elif state == psycopg2.extensions.POLL_READ:
select.select([pollable], [], [])
elif state == psycopg2.extensions.POLL_WRITE:
select.select([], [cur.fileno()], [])
state = cur.poll()
select.select([], [pollable], [])
else:
raise Exception("Unexpected result from poll: %r", state)
def test_connection_setup(self):
cur = self.conn.cursor()
@ -83,7 +80,7 @@ class AsyncTests(unittest.TestCase):
cur.execute("select 'a'")
self.assertTrue(self.conn.executing())
self.wait_for_query(cur)
self.wait(cur)
self.assertFalse(self.conn.executing())
self.assertEquals(cur.fetchone()[0], "a")
@ -97,7 +94,7 @@ class AsyncTests(unittest.TestCase):
return
self.assertTrue(self.conn.executing())
self.wait_for_query(cur)
self.wait(cur)
self.assertFalse(self.conn.executing())
self.assertEquals(cur.fetchall()[0][0], '')
@ -114,17 +111,17 @@ class AsyncTests(unittest.TestCase):
self.assertRaises(psycopg2.ProgrammingError,
cur.callproc, "version")
# but after you've waited it should be good
self.wait_for_query(cur)
self.wait(cur)
cur.execute("select * from table1")
self.wait_for_query(cur)
self.wait(cur)
self.assertEquals(cur.fetchall()[0][0], 1)
cur.execute("delete from table1")
self.wait_for_query(cur)
self.wait(cur)
cur.execute("select * from table1")
self.wait_for_query(cur)
self.wait(cur)
self.assertEquals(cur.fetchone(), None)
@ -136,7 +133,7 @@ class AsyncTests(unittest.TestCase):
self.assertRaises(psycopg2.ProgrammingError,
cur.fetchall)
# but after waiting it should work
self.wait_for_query(cur)
self.wait(cur)
self.assertEquals(cur.fetchall()[0][0], "a")
def test_rollback_while_async(self):
@ -151,7 +148,7 @@ class AsyncTests(unittest.TestCase):
cur = self.conn.cursor()
cur.execute("begin")
self.wait_for_query(cur)
self.wait(cur)
cur.execute("insert into table1 values (1)")
@ -160,19 +157,19 @@ class AsyncTests(unittest.TestCase):
self.assertTrue(self.conn.executing())
# but a manual commit should
self.wait_for_query(cur)
self.wait(cur)
cur.execute("commit")
self.wait_for_query(cur)
self.wait(cur)
cur.execute("select * from table1")
self.wait_for_query(cur)
self.wait(cur)
self.assertEquals(cur.fetchall()[0][0], 1)
cur.execute("delete from table1")
self.wait_for_query(cur)
self.wait(cur)
cur.execute("select * from table1")
self.wait_for_query(cur)
self.wait(cur)
self.assertEquals(cur.fetchone(), None)
def test_set_parameters_while_async(self):
@ -206,16 +203,16 @@ class AsyncTests(unittest.TestCase):
cur = self.conn.cursor()
cur.execute("begin")
self.wait_for_query(cur)
self.wait(cur)
cur.execute("insert into table1 values (1), (2), (3)")
self.wait_for_query(cur)
self.wait(cur)
cur.execute("select id from table1 order by id")
# iteration fails if a query is underway
self.assertRaises(psycopg2.ProgrammingError, list, cur)
# but after it's done it should work
self.wait_for_query(cur)
self.wait(cur)
self.assertEquals(list(cur), [(1, ), (2, ), (3, )])
self.assertFalse(self.conn.executing())
@ -242,7 +239,7 @@ class AsyncTests(unittest.TestCase):
def test_async_scroll(self):
cur = self.conn.cursor()
cur.execute("insert into table1 values (1), (2), (3)")
self.wait_for_query(cur)
self.wait(cur)
cur.execute("select id from table1 order by id")
# scroll should fail if a query is underway
@ -250,13 +247,13 @@ class AsyncTests(unittest.TestCase):
self.assertTrue(self.conn.executing())
# but after it's done it should work
self.wait_for_query(cur)
self.wait(cur)
cur.scroll(1)
self.assertEquals(cur.fetchall(), [(2, ), (3, )])
cur = self.conn.cursor()
cur.execute("select id from table1 order by id")
self.wait_for_query(cur)
self.wait(cur)
cur2 = self.conn.cursor()
self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1)
@ -265,7 +262,7 @@ class AsyncTests(unittest.TestCase):
cur = self.conn.cursor()
cur.execute("select id from table1 order by id")
self.wait_for_query(cur)
self.wait(cur)
cur.scroll(2)
cur.scroll(-1)
self.assertEquals(cur.fetchall(), [(2, ), (3, )])
@ -284,7 +281,7 @@ class AsyncTests(unittest.TestCase):
cur.execute("select repeat('a', 10000); select repeat('b', 10000)")
# fetch the result
self.wait_for_query(cur)
self.wait(cur)
# it should be the result of the second query
self.assertEquals(cur.fetchone()[0], "b" * 10000)