diff --git a/psycopg/connection.h b/psycopg/connection.h index ec8e8691..f2296a9d 100644 --- a/psycopg/connection.h +++ b/psycopg/connection.h @@ -37,6 +37,7 @@ extern "C" { #endif /* connection status */ +#define CONN_STATUS_SETUP 0 #define CONN_STATUS_READY 1 #define CONN_STATUS_BEGIN 2 #define CONN_STATUS_SYNC 3 diff --git a/psycopg/connection_type.c b/psycopg/connection_type.c index 3fd08a48..ab5b23ed 100644 --- a/psycopg/connection_type.c +++ b/psycopg/connection_type.c @@ -425,6 +425,12 @@ psyco_conn_poll(connectionObject *self) switch (self->status) { + case CONN_STATUS_SETUP: + /* according to libpq documentation the user should start by waiting + for the socket to become writable */ + self->status = CONN_STATUS_ASYNC; + return PyInt_FromLong(PSYCO_POLL_WRITE); + case CONN_STATUS_SEND_DATESTYLE: case CONN_STATUS_SENT_DATESTYLE: case CONN_STATUS_SEND_CLIENT_ENCODING: @@ -686,7 +692,7 @@ connection_setup(connectionObject *self, const char *dsn, long int async) self->notifies = PyList_New(0); self->closed = 0; self->async = async; - self->status = async ? CONN_STATUS_ASYNC : CONN_STATUS_READY; + self->status = async ? CONN_STATUS_SETUP : CONN_STATUS_READY; self->critical = NULL; self->async_cursor = NULL; self->pgconn = NULL; diff --git a/tests/test_async.py b/tests/test_async.py index 6cb91490..de853e7e 100755 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -35,33 +35,30 @@ class AsyncTests(unittest.TestCase): self.sync_conn = psycopg2.connect(tests.dsn) self.conn = psycopg2.connect(tests.dsn, async=True) - state = psycopg2.extensions.POLL_WRITE - while state != psycopg2.extensions.POLL_OK: - if state == psycopg2.extensions.POLL_WRITE: - select.select([], [self.conn.fileno()], []) - elif state == psycopg2.extensions.POLL_READ: - select.select([self.conn.fileno()], [], []) - state = self.conn.poll() + self.wait(self.conn) curs = self.conn.cursor() curs.execute(''' CREATE TEMPORARY TABLE table1 ( id int PRIMARY KEY )''') - self.wait_for_query(curs) + self.wait(curs) def tearDown(self): self.sync_conn.close() self.conn.close() - def wait_for_query(self, cur): - state = cur.poll() - while state != psycopg2.extensions.POLL_OK: - if state == psycopg2.extensions.POLL_READ: - select.select([cur.fileno()], [], []) + def wait(self, pollable): + while True: + state = pollable.poll() + if state == psycopg2.extensions.POLL_OK: + break + elif state == psycopg2.extensions.POLL_READ: + select.select([pollable], [], []) elif state == psycopg2.extensions.POLL_WRITE: - select.select([], [cur.fileno()], []) - state = cur.poll() + select.select([], [pollable], []) + else: + raise Exception("Unexpected result from poll: %r", state) def test_connection_setup(self): cur = self.conn.cursor() @@ -83,7 +80,7 @@ class AsyncTests(unittest.TestCase): cur.execute("select 'a'") self.assertTrue(self.conn.executing()) - self.wait_for_query(cur) + self.wait(cur) self.assertFalse(self.conn.executing()) self.assertEquals(cur.fetchone()[0], "a") @@ -97,7 +94,7 @@ class AsyncTests(unittest.TestCase): return self.assertTrue(self.conn.executing()) - self.wait_for_query(cur) + self.wait(cur) self.assertFalse(self.conn.executing()) self.assertEquals(cur.fetchall()[0][0], '') @@ -114,17 +111,17 @@ class AsyncTests(unittest.TestCase): self.assertRaises(psycopg2.ProgrammingError, cur.callproc, "version") # but after you've waited it should be good - self.wait_for_query(cur) + self.wait(cur) cur.execute("select * from table1") - self.wait_for_query(cur) + self.wait(cur) self.assertEquals(cur.fetchall()[0][0], 1) cur.execute("delete from table1") - self.wait_for_query(cur) + self.wait(cur) cur.execute("select * from table1") - self.wait_for_query(cur) + self.wait(cur) self.assertEquals(cur.fetchone(), None) @@ -136,7 +133,7 @@ class AsyncTests(unittest.TestCase): self.assertRaises(psycopg2.ProgrammingError, cur.fetchall) # but after waiting it should work - self.wait_for_query(cur) + self.wait(cur) self.assertEquals(cur.fetchall()[0][0], "a") def test_rollback_while_async(self): @@ -151,7 +148,7 @@ class AsyncTests(unittest.TestCase): cur = self.conn.cursor() cur.execute("begin") - self.wait_for_query(cur) + self.wait(cur) cur.execute("insert into table1 values (1)") @@ -160,19 +157,19 @@ class AsyncTests(unittest.TestCase): self.assertTrue(self.conn.executing()) # but a manual commit should - self.wait_for_query(cur) + self.wait(cur) cur.execute("commit") - self.wait_for_query(cur) + self.wait(cur) cur.execute("select * from table1") - self.wait_for_query(cur) + self.wait(cur) self.assertEquals(cur.fetchall()[0][0], 1) cur.execute("delete from table1") - self.wait_for_query(cur) + self.wait(cur) cur.execute("select * from table1") - self.wait_for_query(cur) + self.wait(cur) self.assertEquals(cur.fetchone(), None) def test_set_parameters_while_async(self): @@ -206,16 +203,16 @@ class AsyncTests(unittest.TestCase): cur = self.conn.cursor() cur.execute("begin") - self.wait_for_query(cur) + self.wait(cur) cur.execute("insert into table1 values (1), (2), (3)") - self.wait_for_query(cur) + self.wait(cur) cur.execute("select id from table1 order by id") # iteration fails if a query is underway self.assertRaises(psycopg2.ProgrammingError, list, cur) # but after it's done it should work - self.wait_for_query(cur) + self.wait(cur) self.assertEquals(list(cur), [(1, ), (2, ), (3, )]) self.assertFalse(self.conn.executing()) @@ -242,7 +239,7 @@ class AsyncTests(unittest.TestCase): def test_async_scroll(self): cur = self.conn.cursor() cur.execute("insert into table1 values (1), (2), (3)") - self.wait_for_query(cur) + self.wait(cur) cur.execute("select id from table1 order by id") # scroll should fail if a query is underway @@ -250,13 +247,13 @@ class AsyncTests(unittest.TestCase): self.assertTrue(self.conn.executing()) # but after it's done it should work - self.wait_for_query(cur) + self.wait(cur) cur.scroll(1) self.assertEquals(cur.fetchall(), [(2, ), (3, )]) cur = self.conn.cursor() cur.execute("select id from table1 order by id") - self.wait_for_query(cur) + self.wait(cur) cur2 = self.conn.cursor() self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1) @@ -265,7 +262,7 @@ class AsyncTests(unittest.TestCase): cur = self.conn.cursor() cur.execute("select id from table1 order by id") - self.wait_for_query(cur) + self.wait(cur) cur.scroll(2) cur.scroll(-1) self.assertEquals(cur.fetchall(), [(2, ), (3, )]) @@ -284,7 +281,7 @@ class AsyncTests(unittest.TestCase): cur.execute("select repeat('a', 10000); select repeat('b', 10000)") # fetch the result - self.wait_for_query(cur) + self.wait(cur) # it should be the result of the second query self.assertEquals(cur.fetchone()[0], "b" * 10000)